BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
Network.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef NETWORK_H
8#define NETWORK_H
9#include <map>
10#include <vector>
11#include "bayesnet/config.h"
12#include "Node.h"
13#include "Smoothing.h"
14
15namespace bayesnet {
16
17 class Network {
18 public:
19 Network();
20 explicit Network(const Network&);
21 ~Network() = default;
22 torch::Tensor& getSamples();
23 void addNode(const std::string&);
24 void addEdge(const std::string&, const std::string&);
25 std::map<std::string, std::unique_ptr<Node>>& getNodes();
26 std::vector<std::string> getFeatures() const;
27 int getStates() const;
28 std::vector<std::pair<std::string, std::string>> getEdges() const;
29 int getNumEdges() const;
30 int getClassNumStates() const;
31 std::string getClassName() const;
32 /*
33 Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on.
34 */
35 void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
36 void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
37 void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing);
38 std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions
39 torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
40 torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);
41 std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities
42 torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities
43 double score(const std::vector<std::vector<int>>&, const std::vector<int>&);
44 std::vector<std::string> topological_sort();
45 std::vector<std::string> show() const;
46 std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
47 void initialize();
48 std::string dump_cpt() const;
49 inline std::string version() { return { project_version.begin(), project_version.end() }; }
50 private:
51 std::map<std::string, std::unique_ptr<Node>> nodes;
52 bool fitted;
53 int classNumStates;
54 std::vector<std::string> features; // Including classname
55 std::string className;
56 torch::Tensor samples; // n+1xm tensor used to fit the model
57 bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
58 std::vector<double> predict_sample(const std::vector<int>&);
59 std::vector<double> predict_sample(const torch::Tensor&);
60 std::vector<double> exactInference(std::map<std::string, int>&);
61 void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
62 void checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
63 void setStates(const std::map<std::string, std::vector<int>>&);
64 };
65}
66#endif