BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
Node.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
#include "Node.h"

#include <stdexcept>
8
9namespace bayesnet {
10
11 Node::Node(const std::string& name)
12 : name(name)
13 {
14 }
15 void Node::clear()
16 {
17 parents.clear();
18 children.clear();
19 cpTable = torch::Tensor();
20 dimensions.clear();
21 numStates = 0;
22 }
23 std::string Node::getName() const
24 {
25 return name;
26 }
27 void Node::addParent(Node* parent)
28 {
29 parents.push_back(parent);
30 }
31 void Node::removeParent(Node* parent)
32 {
33 parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
34 }
35 void Node::removeChild(Node* child)
36 {
37 children.erase(std::remove(children.begin(), children.end(), child), children.end());
38 }
39 void Node::addChild(Node* child)
40 {
41 children.push_back(child);
42 }
43 std::vector<Node*>& Node::getParents()
44 {
45 return parents;
46 }
47 std::vector<Node*>& Node::getChildren()
48 {
49 return children;
50 }
51 int Node::getNumStates() const
52 {
53 return numStates;
54 }
55 void Node::setNumStates(int numStates)
56 {
57 this->numStates = numStates;
58 }
59 torch::Tensor& Node::getCPT()
60 {
61 return cpTable;
62 }
63 /*
64 The MinFill criterion is a heuristic for variable elimination.
65 The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
66 This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
67 The variable with the minimum number of edges is chosen.
68 Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
69 */
70 unsigned Node::minFill()
71 {
72 std::unordered_set<std::string> neighbors;
73 for (auto child : children) {
74 neighbors.emplace(child->getName());
75 }
76 for (auto parent : parents) {
77 neighbors.emplace(parent->getName());
78 }
79 auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
80 return combinations(source).size();
81 }
82 std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
83 {
84 std::vector<std::pair<std::string, std::string>> result;
85 for (int i = 0; i < source.size(); ++i) {
86 std::string temp = source[i];
87 for (int j = i + 1; j < source.size(); ++j) {
88 result.push_back({ temp, source[j] });
89 }
90 }
91 return result;
92 }
93 void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double smoothing, const torch::Tensor& weights)
94 {
95 dimensions.clear();
96 dimensions.reserve(parents.size() + 1);
97 // Get dimensions of the CPT
98 dimensions.push_back(numStates);
99 for (const auto& parent : parents) {
100 dimensions.push_back(parent->getNumStates());
101 }
102 //transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
103 // Create a tensor initialized with smoothing
104 cpTable = torch::full(dimensions, smoothing, torch::kDouble);
105 // Create a map for quick feature index lookup
106 std::unordered_map<std::string, int> featureIndexMap;
107 for (size_t i = 0; i < features.size(); ++i) {
108 featureIndexMap[features[i]] = i;
109 }
110 // Fill table with counts
111 // Get the index of this node's feature
112 int name_index = featureIndexMap[name];
113 // Get parent indices in dataset
114 std::vector<int> parent_indices;
115 parent_indices.reserve(parents.size());
116 for (const auto& parent : parents) {
117 parent_indices.push_back(featureIndexMap[parent->getName()]);
118 }
119 c10::List<c10::optional<at::Tensor>> coordinates;
120 for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
121 coordinates.clear();
122 auto sample = dataset.index({ "...", n_sample });
123 coordinates.push_back(sample[name_index]);
124 for (size_t i = 0; i < parent_indices.size(); ++i) {
125 coordinates.push_back(sample[parent_indices[i]]);
126 }
127 // Increment the count of the corresponding coordinate
128 cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
129 }
130 // Normalize the counts (dividing each row by the sum of the row)
131 cpTable /= cpTable.sum(0, true);
132 }
133 double Node::getFactorValue(std::map<std::string, int>& evidence)
134 {
135 c10::List<c10::optional<at::Tensor>> coordinates;
136 // following predetermined order of indices in the cpTable (see Node.h)
137 coordinates.push_back(at::tensor(evidence[name]));
138 transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
139 return cpTable.index({ coordinates }).item<double>();
140 }
141 std::vector<std::string> Node::graph(const std::string& className)
142 {
143 auto output = std::vector<std::string>();
144 auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
145 output.push_back("\"" + name + "\" [shape=circle" + suffix + "] \n");
146 transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return "\"" + name + "\" -> \"" + child->getName() + "\""; });
147 return output;
148 }
149}