// NOTE(review): the leading numbers on these declaration lines ("22", "23", ...) look like
// source-listing residue rather than C++ tokens; they are intentionally left untouched here.
//
// Returns a non-const reference to a samples tensor, so callers can modify it in place;
// presumably this is the `samples` data member — confirm in the implementation.
22 torch::Tensor& getSamples();
// Adds a node; the string argument is presumably the node (variable) name.
23 void addNode(
const std::string&);
// Adds an edge between two nodes identified by string; presumably (parent, child) order — confirm.
24 void addEdge(
const std::string&,
const std::string&);
// Returns a non-const reference to a string -> Node map whose nodes are held by std::unique_ptr;
// callers can mutate the map (and its nodes) through this reference.
25 std::map<std::string, std::unique_ptr<Node>>& getNodes();
// Returns the feature names by value (the caller gets its own vector).
26 std::vector<std::string> getFeatures()
const;
// Returns an int "states" value; what exactly it counts is not visible from the
// declaration — TODO confirm against the implementation.
27 int getStates()
const;
// Returns the edges by value as pairs of node names; presumably (from, to) order — confirm.
28 std::vector<std::pair<std::string, std::string>> getEdges()
const;
// Presumably the number of edges (i.e. the size of getEdges()) — confirm.
29 int getNumEdges()
const;
// Presumably the number of states of the class (target) variable — confirm.
30 int getClassNumStates()
const;
// Returns by value the name of what is presumably the class (target) variable.
31 std::string getClassName()
const;
// Training entry points. All three overloads take the feature names, the class name, a map
// from variable name to a list of int states, and a smoothing setting; they differ only in
// how the training data and weights are passed in.
//
// Overload 1: plain std containers. input_data is a vector of int vectors, labels holds the
// class values and weights the sample weights (presumably one label/weight per sample).
// NOTE(review): the orientation of input_data (features x samples vs samples x features)
// is not visible from the declaration — confirm in the implementation.
35 void fit(
const std::vector<std::vector<int>>& input_data,
const std::vector<int>& labels,
const std::vector<double>& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing);
// Overload 2: features X and class values y as tensors, with a weights tensor.
// Expected tensor shapes/dtypes are not visible here — TODO confirm.
36 void fit(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing);
// Overload 3: a single `samples` tensor plus weights; presumably the samples tensor holds the
// features and the class values together — confirm the expected layout.
37 void fit(
const torch::Tensor& samples,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing);
// Prediction API. Each entry point has a std-container overload and a tensor overload.
// NOTE(review): the expected orientation of the input (features x samples vs samples x
// features) is not visible from these declarations — confirm in the implementation.
//
// Returns one int per predicted sample (presumably the predicted class value).
38 std::vector<int> predict(
const std::vector<std::vector<int>>&);
// Tensor overload of predict.
39 torch::Tensor predict(
const torch::Tensor&);
// Tensor prediction with a flag; presumably returns class probabilities when `proba` is true
// and predicted labels otherwise — confirm.
40 torch::Tensor predict_tensor(
const torch::Tensor& samples,
const bool proba);
// Returns one vector of doubles per sample (presumably per-class probabilities).
41 std::vector<std::vector<double>> predict_proba(
const std::vector<std::vector<int>>&);
// Tensor overload of predict_proba.
42 torch::Tensor predict_proba(
const torch::Tensor&);
// Scores predictions on the given inputs against the given int labels; the metric
// (presumably accuracy) is not visible here — confirm.
43 double score(
const std::vector<std::vector<int>>&,
const std::vector<int>&);
// Returns node names, presumably in topological order. Declared non-const, so the
// implementation may modify object state — confirm whether that is necessary.
44 std::vector<std::string> topological_sort();
// Returns a list of text lines describing the network (presumably its nodes and edges).
45 std::vector<std::string> show()
const;
// Returns text lines for a graph rendering of the network under the given title
// (presumably a graph-description format such as DOT — confirm).
46 std::vector<std::string> graph(
const std::string& title)
const;
// Returns a single string, presumably a dump of the conditional probability tables — confirm.
48 std::string dump_cpt()
const;
// Returns the version as a std::string built from the character range of `project_version`
// (defined outside this class).
// Declared const: it only reads `project_version`, so it must be callable on const
// instances, consistent with the class's other const accessors. `inline` is dropped
// because member functions defined in the class body are implicitly inline.
49 std::string version() const { return { project_version.begin(), project_version.end() }; }
// Network nodes keyed by string (presumably the node name); each Node is owned by the
// map through std::unique_ptr.
51 std::map<std::string, std::unique_ptr<Node>> nodes;
// Feature names (presumably in the order the feature data is laid out — confirm).
54 std::vector<std::string> features;
// Presumably the name of the class (target) variable.
55 std::string className;
// Samples tensor; presumably the object getSamples() returns a reference to — confirm.
56 torch::Tensor samples;
// Cycle check starting from the named node. The two mutable string sets are presumably a
// "visited" set and a "currently on the search path" set — confirm. Presumably returns true
// when a cycle is found.
57 bool isCyclic(
const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
// Predicts a single sample given as ints; returns a vector of doubles
// (presumably one probability per class state).
58 std::vector<double> predict_sample(
const std::vector<int>&);
// Tensor overload of predict_sample.
59 std::vector<double> predict_sample(
const torch::Tensor&);
// Inference over a string -> int map (presumably evidence: variable name -> observed state).
// NOTE(review): the map is taken by non-const reference; if the implementation does not
// modify it, `const std::map<std::string, int>&` would be the safer signature — confirm.
60 std::vector<double> exactInference(std::map<std::string, int>&);
// Presumably the shared final step of the fit() overloads: receives the states map,
// the weights tensor and the smoothing setting.
61 void completeFit(
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights,
const Smoothing_t smoothing);
// Presumably validates fit() inputs (sample/feature counts, names, states, weights) before
// training; the failure behavior (e.g. which exception is thrown) is not visible here — confirm.
62 void checkFitData(
int n_samples,
int n_features,
int n_samples_y,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights);
// Applies a variable-name -> state-list map (presumably onto the nodes) — confirm.
63 void setStates(
const std::map<std::string, std::vector<int>>&);