19 explicit XSpode(
int spIndex);
20 std::vector<double> predict_proba(
const std::vector<int>& instance)
const;
21 std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X)
override;
22 int predict(
const std::vector<int>& instance)
const;
23 void normalize(std::vector<double>& v)
const;
24 std::string to_string()
const;
25 int getNFeatures()
const;
26 int getNumberOfNodes()
const override;
27 int getNumberOfEdges()
const override;
28 int getNumberOfStates()
const override;
29 int getClassNumStates()
const override;
30 std::vector<int>& getStates();
31 std::vector<std::string> graph(
const std::string& title)
const override {
return std::vector<std::string>({ title }); }
32 void fitx(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_,
const Smoothing_t smoothing);
33 void setHyperparameters(
const nlohmann::json& hyperparameters_)
override;
38 torch::Tensor predict(torch::Tensor& X)
override;
39 std::vector<int> predict(std::vector<std::vector<int>>& X)
override;
40 torch::Tensor predict_proba(torch::Tensor& X)
override;
41 float score(torch::Tensor& X, torch::Tensor& y)
override;
42 float score(std::vector<std::vector<int>>& X, std::vector<int>& y)
override;
44 void buildModel(
const torch::Tensor& weights)
override;
45 void trainModel(
const torch::Tensor& weights,
const bayesnet::Smoothing_t smoothing)
override;
/// Accumulates one instance with the given weight; presumably updates this
/// class's *Counts_ members - confirm against the definition.
void addSample(const std::vector<int>& instance, double weight);
/// Presumably turns the accumulated counts into the *Probs_ / classPriors_
/// tables - confirm against the definition.
void computeProbabilities();
52 std::vector<int> states_;
55 std::vector<double> classCounts_;
56 std::vector<double> classPriors_;
59 std::vector<double> spFeatureCounts_;
60 std::vector<double> spFeatureProbs_;
66 std::vector<double> childCounts_;
67 std::vector<double> childProbs_;
68 std::vector<int> childOffsets_;
72 CountingSemaphore& semaphore_;