BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
Network.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
#include <algorithm>
#include <exception>
#include <fstream>
#include <iterator>
#include <mutex>
#include <numeric>
#include <sstream>
#include <thread>
#include <pthread.h>
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
#include "bayesnet/utils/CountingSemaphore.h"
16namespace bayesnet {
17 Network::Network() : fitted{ false }, classNumStates{ 0 }
18 {
19 }
20 Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
21 fitted(other.fitted), samples(other.samples)
22 {
23 if (samples.defined())
24 samples = samples.clone();
25 for (const auto& node : other.nodes) {
26 nodes[node.first] = std::make_unique<Node>(*node.second);
27 }
28 }
29 void Network::initialize()
30 {
31 features.clear();
32 className = "";
33 classNumStates = 0;
34 fitted = false;
35 nodes.clear();
36 samples = torch::Tensor();
37 }
38 torch::Tensor& Network::getSamples()
39 {
40 return samples;
41 }
42 void Network::addNode(const std::string& name)
43 {
44 if (fitted) {
45 throw std::invalid_argument("Cannot add node to a fitted network. Initialize first.");
46 }
47 if (name == "") {
48 throw std::invalid_argument("Node name cannot be empty");
49 }
50 if (nodes.find(name) != nodes.end()) {
51 return;
52 }
53 if (find(features.begin(), features.end(), name) == features.end()) {
54 features.push_back(name);
55 }
56 nodes[name] = std::make_unique<Node>(name);
57 }
58 std::vector<std::string> Network::getFeatures() const
59 {
60 return features;
61 }
62 int Network::getClassNumStates() const
63 {
64 return classNumStates;
65 }
66 int Network::getStates() const
67 {
68 int result = 0;
69 for (auto& node : nodes) {
70 result += node.second->getNumStates();
71 }
72 return result;
73 }
74 std::string Network::getClassName() const
75 {
76 return className;
77 }
78 bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
79 {
80 if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
81 {
82 visited.insert(nodeId);
83 recStack.insert(nodeId);
84 for (Node* child : nodes[nodeId]->getChildren()) {
85 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
86 return true;
87 if (recStack.find(child->getName()) != recStack.end())
88 return true;
89 }
90 }
91 recStack.erase(nodeId); // remove node from recursion stack before function ends
92 return false;
93 }
94 void Network::addEdge(const std::string& parent, const std::string& child)
95 {
96 if (fitted) {
97 throw std::invalid_argument("Cannot add edge to a fitted network. Initialize first.");
98 }
99 if (nodes.find(parent) == nodes.end()) {
100 throw std::invalid_argument("Parent node " + parent + " does not exist");
101 }
102 if (nodes.find(child) == nodes.end()) {
103 throw std::invalid_argument("Child node " + child + " does not exist");
104 }
105 // Check if the edge is already in the graph
106 for (auto& node : nodes[parent]->getChildren()) {
107 if (node->getName() == child) {
108 throw std::invalid_argument("Edge " + parent + " -> " + child + " already exists");
109 }
110 }
111 // Temporarily add edge to check for cycles
112 nodes[parent]->addChild(nodes[child].get());
113 nodes[child]->addParent(nodes[parent].get());
114 std::unordered_set<std::string> visited;
115 std::unordered_set<std::string> recStack;
116 if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
117 {
118 // remove problematic edge
119 nodes[parent]->removeChild(nodes[child].get());
120 nodes[child]->removeParent(nodes[parent].get());
121 throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
122 }
123 }
124 std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
125 {
126 return nodes;
127 }
128 void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
129 {
130 if (weights.size(0) != n_samples) {
131 throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
132 }
133 if (n_samples != n_samples_y) {
134 throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
135 }
136 if (n_features != featureNames.size()) {
137 throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
138 }
139 if (features.size() == 0) {
140 throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
141 }
142 if (n_features != features.size() - 1) {
143 throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
144 }
145 if (find(features.begin(), features.end(), className) == features.end()) {
146 throw std::invalid_argument("Class Name not found in Network::features");
147 }
148 for (auto& feature : featureNames) {
149 if (find(features.begin(), features.end(), feature) == features.end()) {
150 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
151 }
152 if (states.find(feature) == states.end()) {
153 throw std::invalid_argument("Feature " + feature + " not found in states");
154 }
155 }
156 }
157 void Network::setStates(const std::map<std::string, std::vector<int>>& states)
158 {
159 // Set states to every Node in the network
160 for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
161 nodes.at(feature)->setNumStates(states.at(feature).size());
162 });
163 classNumStates = nodes.at(className)->getNumStates();
164 }
165 // X comes in nxm, where n is the number of features and m the number of samples
166 void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
167 {
168 checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
169 this->className = className;
170 torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
171 samples = torch::cat({ X , ytmp }, 0);
172 for (int i = 0; i < featureNames.size(); ++i) {
173 auto row_feature = X.index({ i, "..." });
174 }
175 completeFit(states, weights, smoothing);
176 }
177 void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
178 {
179 checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
180 this->className = className;
181 this->samples = samples;
182 completeFit(states, weights, smoothing);
183 }
184 // input_data comes in nxm, where n is the number of features and m the number of samples
185 void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
186 {
187 const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
188 checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
189 this->className = className;
190 // Build tensor of samples (nxm) (n+1 because of the class)
191 samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
192 for (int i = 0; i < featureNames.size(); ++i) {
193 samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
194 }
195 samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
196 completeFit(states, weights, smoothing);
197 }
198 void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
199 {
200 setStates(states);
201 std::vector<std::thread> threads;
202 auto& semaphore = CountingSemaphore::getInstance();
203 const double n_samples = static_cast<double>(samples.size(1));
204 auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
205 std::string threadName = "FitWorker-" + std::to_string(i);
206#if defined(__linux__)
207 pthread_setname_np(pthread_self(), threadName.c_str());
208#else
209 pthread_setname_np(threadName.c_str());
210#endif
211 double numStates = static_cast<double>(node.second->getNumStates());
212 double smoothing_factor;
213 switch (smoothing) {
214 case Smoothing_t::ORIGINAL:
215 smoothing_factor = 1.0 / n_samples;
216 break;
217 case Smoothing_t::LAPLACE:
218 smoothing_factor = 1.0;
219 break;
220 case Smoothing_t::CESTNIK:
221 smoothing_factor = 1 / numStates;
222 break;
223 default:
224 smoothing_factor = 0.0; // No smoothing
225 }
226 node.second->computeCPT(samples, features, smoothing_factor, weights);
227 semaphore.release();
228 };
229 int i = 0;
230 for (auto& node : nodes) {
231 semaphore.acquire();
232 threads.emplace_back(worker, std::ref(node), i++);
233 }
234 for (auto& thread : threads) {
235 thread.join();
236 }
237 fitted = true;
238 }
239 torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
240 {
241 if (!fitted) {
242 throw std::logic_error("You must call fit() before calling predict()");
243 }
244 // Ensure the sample size is equal to the number of features
245 if (samples.size(0) != features.size() - 1) {
246 throw std::invalid_argument("(T) Sample size (" + std::to_string(samples.size(0)) +
247 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
248 }
249 torch::Tensor result;
250 std::vector<std::thread> threads;
251 std::mutex mtx;
252 auto& semaphore = CountingSemaphore::getInstance();
253 result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
254 auto worker = [&](const torch::Tensor& sample, int i) {
255 std::string threadName = "PredictWorker-" + std::to_string(i);
256#if defined(__linux__)
257 pthread_setname_np(pthread_self(), threadName.c_str());
258#else
259 pthread_setname_np(threadName.c_str());
260#endif
261 auto psample = predict_sample(sample);
262 auto temp = torch::tensor(psample, torch::kFloat64);
263 {
264 std::lock_guard<std::mutex> lock(mtx);
265 result.index_put_({ i, "..." }, temp);
266 }
267 semaphore.release();
268 };
269 for (int i = 0; i < samples.size(1); ++i) {
270 semaphore.acquire();
271 const torch::Tensor sample = samples.index({ "...", i });
272 threads.emplace_back(worker, sample, i);
273 }
274 for (auto& thread : threads) {
275 thread.join();
276 }
277 if (proba)
278 return result;
279 return result.argmax(1);
280 }
281 // Return mxn tensor of probabilities
282 torch::Tensor Network::predict_proba(const torch::Tensor& samples)
283 {
284 return predict_tensor(samples, true);
285 }
286
287 // Return mxn tensor of probabilities
288 torch::Tensor Network::predict(const torch::Tensor& samples)
289 {
290 return predict_tensor(samples, false);
291 }
292
293 // Return mx1 std::vector of predictions
294 // tsamples is nxm std::vector of samples
295 std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
296 {
297 if (!fitted) {
298 throw std::logic_error("You must call fit() before calling predict()");
299 }
300 // Ensure the sample size is equal to the number of features
301 if (tsamples.size() != features.size() - 1) {
302 throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
303 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
304 }
305 std::vector<int> predictions(tsamples[0].size(), 0);
306 std::vector<int> sample;
307 std::vector<std::thread> threads;
308 auto& semaphore = CountingSemaphore::getInstance();
309 auto worker = [&](const std::vector<int>& sample, const int row, int& prediction) {
310 std::string threadName = "(V)PWorker-" + std::to_string(row);
311#if defined(__linux__)
312 pthread_setname_np(pthread_self(), threadName.c_str());
313#else
314 pthread_setname_np(threadName.c_str());
315#endif
316 auto classProbabilities = predict_sample(sample);
317 auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
318 int predictedClass = distance(classProbabilities.begin(), maxElem);
319 prediction = predictedClass;
320 semaphore.release();
321 };
322 for (int row = 0; row < tsamples[0].size(); ++row) {
323 sample.clear();
324 for (int col = 0; col < tsamples.size(); ++col) {
325 sample.push_back(tsamples[col][row]);
326 }
327 semaphore.acquire();
328 threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
329 }
330 for (auto& thread : threads) {
331 thread.join();
332 }
333 return predictions;
334 }
335 // Return mxn std::vector of probabilities
336 // tsamples is nxm std::vector of samples
337 std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
338 {
339 if (!fitted) {
340 throw std::logic_error("You must call fit() before calling predict_proba()");
341 }
342 // Ensure the sample size is equal to the number of features
343 if (tsamples.size() != features.size() - 1) {
344 throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
345 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
346 }
347 std::vector<std::vector<double>> predictions(tsamples[0].size(), std::vector<double>(classNumStates, 0.0));
348 std::vector<int> sample;
349 std::vector<std::thread> threads;
350 auto& semaphore = CountingSemaphore::getInstance();
351 auto worker = [&](const std::vector<int>& sample, int row, std::vector<double>& predictions) {
352 std::string threadName = "(V)PWorker-" + std::to_string(row);
353#if defined(__linux__)
354 pthread_setname_np(pthread_self(), threadName.c_str());
355#else
356 pthread_setname_np(threadName.c_str());
357#endif
358 std::vector<double> classProbabilities = predict_sample(sample);
359 predictions = classProbabilities;
360 semaphore.release();
361 };
362 for (int row = 0; row < tsamples[0].size(); ++row) {
363 sample.clear();
364 for (int col = 0; col < tsamples.size(); ++col) {
365 sample.push_back(tsamples[col][row]);
366 }
367 semaphore.acquire();
368 threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
369 }
370 for (auto& thread : threads) {
371 thread.join();
372 }
373 return predictions;
374 }
375 double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
376 {
377 std::vector<int> y_pred = predict(tsamples);
378 int correct = 0;
379 for (int i = 0; i < y_pred.size(); ++i) {
380 if (y_pred[i] == labels[i]) {
381 correct++;
382 }
383 }
384 return (double)correct / y_pred.size();
385 }
386 // Return 1xn std::vector of probabilities
387 std::vector<double> Network::predict_sample(const std::vector<int>& sample)
388 {
389 std::map<std::string, int> evidence;
390 for (int i = 0; i < sample.size(); ++i) {
391 evidence[features[i]] = sample[i];
392 }
393 return exactInference(evidence);
394 }
395 // Return 1xn std::vector of probabilities
396 std::vector<double> Network::predict_sample(const torch::Tensor& sample)
397 {
398 std::map<std::string, int> evidence;
399 for (int i = 0; i < sample.size(0); ++i) {
400 evidence[features[i]] = sample[i].item<int>();
401 }
402 return exactInference(evidence);
403 }
404 std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
405 {
406 std::vector<double> result(classNumStates, 0.0);
407 auto completeEvidence = std::map<std::string, int>(evidence);
408 for (int i = 0; i < classNumStates; ++i) {
409 completeEvidence[getClassName()] = i;
410 double partial = 1.0;
411 for (auto& node : getNodes()) {
412 partial *= node.second->getFactorValue(completeEvidence);
413 }
414 result[i] = partial;
415 }
416 // Normalize result
417 double sum = std::accumulate(result.begin(), result.end(), 0.0);
418 transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
419 return result;
420 }
421 std::vector<std::string> Network::show() const
422 {
423 std::vector<std::string> result;
424 // Draw the network
425 for (auto& node : nodes) {
426 std::string line = node.first + " -> ";
427 for (auto child : node.second->getChildren()) {
428 line += child->getName() + ", ";
429 }
430 result.push_back(line);
431 }
432 return result;
433 }
434 std::vector<std::string> Network::graph(const std::string& title) const
435 {
436 auto output = std::vector<std::string>();
437 auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
438 auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
439 std::string header = prefix + title + suffix;
440 output.push_back(header);
441 for (auto& node : nodes) {
442 auto result = node.second->graph(className);
443 output.insert(output.end(), result.begin(), result.end());
444 }
445 output.push_back("}\n");
446 return output;
447 }
448 std::vector<std::pair<std::string, std::string>> Network::getEdges() const
449 {
450 auto edges = std::vector<std::pair<std::string, std::string>>();
451 for (const auto& node : nodes) {
452 auto head = node.first;
453 for (const auto& child : node.second->getChildren()) {
454 auto tail = child->getName();
455 edges.push_back({ head, tail });
456 }
457 }
458 return edges;
459 }
460 int Network::getNumEdges() const
461 {
462 return getEdges().size();
463 }
464 std::vector<std::string> Network::topological_sort()
465 {
466 /* Check if al the fathers of every node are before the node */
467 auto result = features;
468 result.erase(remove(result.begin(), result.end(), className), result.end());
469 bool ending{ false };
470 while (!ending) {
471 ending = true;
472 for (auto feature : features) {
473 auto fathers = nodes[feature]->getParents();
474 for (const auto& father : fathers) {
475 auto fatherName = father->getName();
476 if (fatherName == className) {
477 continue;
478 }
479 // Check if father is placed before the actual feature
480 auto it = find(result.begin(), result.end(), fatherName);
481 if (it != result.end()) {
482 auto it2 = find(result.begin(), result.end(), feature);
483 if (it2 != result.end()) {
484 if (distance(it, it2) < 0) {
485 // if it is not, insert it before the feature
486 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
487 result.insert(it2, fatherName);
488 ending = false;
489 }
490 }
491 }
492 }
493 }
494 }
495 return result;
496 }
497 std::string Network::dump_cpt() const
498 {
499 std::stringstream oss;
500 for (auto& node : nodes) {
501 oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
502 oss << node.second->getCPT() << std::endl;
503 }
504 return oss.str();
505 }
506}