12#include "bayesnet/utils/bayesnetUtils.h"
13#include "bayesnet/utils/CountingSemaphore.h"
17 Network::Network() : fitted{ false }, classNumStates{ 0 }
20 Network::Network(
const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
21 fitted(other.fitted), samples(other.samples)
23 if (samples.defined())
24 samples = samples.clone();
25 for (
const auto& node : other.nodes) {
26 nodes[node.first] = std::make_unique<Node>(*node.second);
// Resets the network so it can be rebuilt.
// addNode()/addEdge() tell callers of a fitted network to "Initialize first",
// so this presumably also clears the fitted state.
// Only the last statement of the body is visible in this view: it replaces
// `samples` with a default-constructed (undefined) torch::Tensor.
// NOTE(review): the preceding reset statements are not visible here. Confirm
// which members (features, nodes, className, classNumStates, fitted) are cleared.
29 void Network::initialize()
36 samples = torch::Tensor();
// Accessor for the samples tensor assembled by the fit() overloads.
// The `torch::Tensor&` return type gives callers a non-const reference, so they
// can modify the stored tensor in place.
// NOTE(review): the body is not visible in this view; presumably it returns the
// `samples` member.
38 torch::Tensor& Network::getSamples()
// Adds a node named `name` to a network that has not been fitted yet.
// Throws std::invalid_argument for a fitted network or an empty name.
// NOTE(review): only the throw statements are visible in this view; the guarding
// conditions are inferred from the messages.
42 void Network::addNode(
const std::string& name)
45 throw std::invalid_argument(
"Cannot add node to a fitted network. Initialize first.");
48 throw std::invalid_argument(
"Node name cannot be empty");
// A node with this name already exists.
// NOTE(review): the body of this branch is not visible here. Confirm whether it
// returns early or throws.
50 if (nodes.find(name) != nodes.end()) {
// Record the name in `features` only once, then create the Node itself.
53 if (find(features.begin(), features.end(), name) == features.end()) {
54 features.push_back(name);
56 nodes[name] = std::make_unique<Node>(name);
// Returns the feature (node) names by value, so callers get their own copy.
// NOTE(review): the body is not visible in this view; presumably it returns the
// `features` member, which addNode() appends to.
58 std::vector<std::string> Network::getFeatures()
const
62 int Network::getClassNumStates()
const
64 return classNumStates;
// Sums getNumStates() over every node of the network.
// NOTE(review): the declaration of `result` and the return statement are not
// visible in this view; presumably `result` starts at 0 and is returned.
66 int Network::getStates()
const
69 for (
auto& node : nodes) {
70 result += node.second->getNumStates();
// Accessor for the class variable's name, returned by value.
// NOTE(review): the body is not visible in this view; presumably it returns the
// `className` member, which the fit() overloads assign.
74 std::string Network::getClassName()
const
// Recursive depth-first search that addEdge() uses to detect a directed cycle
// reachable from `nodeId`.
// `visited` accumulates every node already explored. `recStack` holds the nodes on
// the current recursion path, so reaching a child that is still in `recStack`
// means the graph loops back on itself.
// NOTE(review): the return statements are not visible in this view. Presumably the
// two child checks return true, and the function otherwise removes nodeId from
// recStack and returns false.
// NOTE(review): nodes[nodeId] uses operator[], which would insert a null entry for
// an unknown id. Callers are expected to pass existing node names.
78 bool Network::isCyclic(
const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
80 if (visited.find(nodeId) == visited.end())
// First visit: mark the node as explored and as part of the current path.
82 visited.insert(nodeId);
83 recStack.insert(nodeId);
84 for (Node* child : nodes[nodeId]->getChildren()) {
// Recurse into children that have not been explored yet...
85 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
// ...and treat a child that is still on the current path as closing a cycle.
87 if (recStack.find(child->getName()) != recStack.end())
// Leaving this node: it is no longer on the current search path.
91 recStack.erase(nodeId);
// Adds the directed edge parent -> child. If the new edge would make the graph
// cyclic, the edge is rolled back and the call throws.
// Throws std::invalid_argument when:
// - the network is fitted,
// - either endpoint does not exist,
// - the edge already exists, or
// - the edge would create a cycle.
// NOTE(review): the condition guarding the "fitted" throw is not visible in this view.
94 void Network::addEdge(
const std::string& parent,
const std::string& child)
97 throw std::invalid_argument(
"Cannot add edge to a fitted network. Initialize first.");
99 if (nodes.find(parent) == nodes.end()) {
100 throw std::invalid_argument(
"Parent node " + parent +
" does not exist");
102 if (nodes.find(child) == nodes.end()) {
103 throw std::invalid_argument(
"Child node " + child +
" does not exist");
// Reject duplicate edges: scan the parent's current children for `child`.
106 for (
auto& node : nodes[parent]->getChildren()) {
107 if (node->getName() == child) {
108 throw std::invalid_argument(
"Edge " + parent +
" -> " + child +
" already exists");
// Link both directions first, then check for a cycle. A new cycle has to run
// through the new edge, so the search starts from `child`.
112 nodes[parent]->addChild(nodes[child].get());
113 nodes[child]->addParent(nodes[parent].get());
114 std::unordered_set<std::string> visited;
115 std::unordered_set<std::string> recStack;
116 if (isCyclic(nodes[child]->getName(), visited, recStack))
// Undo the link so the graph is left unchanged before reporting the error.
119 nodes[parent]->removeChild(nodes[child].get());
120 nodes[child]->removeParent(nodes[parent].get());
121 throw std::invalid_argument(
"Adding this edge forms a cycle in the graph.");
// Mutable access to the name -> Node map. The return type is a non-const reference
// to std::map<std::string, std::unique_ptr<Node>>.
// NOTE(review): the body is not visible in this view; presumably it returns the
// `nodes` member.
124 std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
128 void Network::checkFitData(
int n_samples,
int n_features,
int n_samples_y,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights)
130 if (weights.size(0) != n_samples) {
131 throw std::invalid_argument(
"Weights (" + std::to_string(weights.size(0)) +
") must have the same number of elements as samples (" + std::to_string(n_samples) +
") in Network::fit");
133 if (n_samples != n_samples_y) {
134 throw std::invalid_argument(
"X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) +
" != " + std::to_string(n_samples_y) +
")");
136 if (n_features != featureNames.size()) {
137 throw std::invalid_argument(
"X and features must have the same number of features in Network::fit (" + std::to_string(n_features) +
" != " + std::to_string(featureNames.size()) +
")");
139 if (features.size() == 0) {
140 throw std::invalid_argument(
"The network has not been initialized. You must call addNode() before calling fit()");
142 if (n_features != features.size() - 1) {
143 throw std::invalid_argument(
"X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) +
" != " + std::to_string(features.size() - 1) +
")");
145 if (find(features.begin(), features.end(), className) == features.end()) {
146 throw std::invalid_argument(
"Class Name not found in Network::features");
148 for (
auto& feature : featureNames) {
149 if (find(features.begin(), features.end(), feature) == features.end()) {
150 throw std::invalid_argument(
"Feature " + feature +
" not found in Network::features");
152 if (states.find(feature) == states.end()) {
153 throw std::invalid_argument(
"Feature " + feature +
" not found in states");
// Gives every network feature its number of states and caches the class
// variable's state count in classNumStates.
// For each name in `features`, nodes.at(name) receives states.at(name).size()
// states. .at() throws std::out_of_range if a feature is missing from either map,
// and that includes the class node, which is one of `features`.
// NOTE(review): the closing of the lambda / for_each call is not visible in this view.
157 void Network::setStates(
const std::map<std::string, std::vector<int>>& states)
160 for_each(features.begin(), features.end(), [
this, &states](
const std::string& feature) {
161 nodes.at(feature)->setNumStates(states.at(feature).size());
163 classNumStates = nodes.at(className)->getNumStates();
166 void Network::fit(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing)
168 checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
169 this->className = className;
170 torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
171 samples = torch::cat({ X , ytmp }, 0);
172 for (
int i = 0; i < featureNames.size(); ++i) {
173 auto row_feature = X.index({ i,
"..." });
175 completeFit(states, weights, smoothing);
177 void Network::fit(
const torch::Tensor& samples,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing)
179 checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
180 this->className = className;
181 this->samples = samples;
182 completeFit(states, weights, smoothing);
185 void Network::fit(
const std::vector<std::vector<int>>& input_data,
const std::vector<int>& labels,
const std::vector<double>& weights_,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const Smoothing_t smoothing)
187 const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
188 checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
189 this->className = className;
191 samples = torch::zeros({
static_cast<int>(input_data.size() + 1),
static_cast<int>(input_data[0].size()) }, torch::kInt32);
192 for (
int i = 0; i < featureNames.size(); ++i) {
193 samples.index_put_({ i,
"..." }, torch::tensor(input_data[i], torch::kInt32));
195 samples.index_put_({ -1,
"..." }, torch::tensor(labels, torch::kInt32));
196 completeFit(states, weights, smoothing);
// Computes the CPT of every node. Each node gets its own std::thread, named
// "FitWorker-<i>".
// The node's smoothing factor depends on `smoothing`:
// - ORIGINAL: 1 / n_samples (samples.size(1)).
// - LAPLACE: 1.
// - CESTNIK: 1 / the node's number of states.
// - Any other value: 0.
// Concurrency is bounded through the CountingSemaphore singleton.
// NOTE(review): many lines are not visible in this view: the semaphore
// acquire/release calls, the `switch (smoothing)` header and its `break`s, the
// `#else`/`#endif` lines, the declaration of `i`, the joins, and whatever uses
// `states` (presumably setStates(states)) and sets `fitted`.
// NOTE(review): an exception escaping a std::thread body (e.g. from computeCPT)
// calls std::terminate.
198 void Network::completeFit(
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights,
const Smoothing_t smoothing)
201 std::vector<std::thread> threads;
202 auto& semaphore = CountingSemaphore::getInstance();
203 const double n_samples =
static_cast<double>(samples.size(1));
// Worker: name the thread, pick the smoothing factor, compute this node's CPT.
204 auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node,
int i) {
205 std::string threadName =
"FitWorker-" + std::to_string(i);
206#if defined(__linux__)
207 pthread_setname_np(pthread_self(), threadName.c_str());
// Presumably the non-Linux (single-argument, e.g. macOS) variant.
209 pthread_setname_np(threadName.c_str());
211 double numStates =
static_cast<double>(node.second->getNumStates());
212 double smoothing_factor;
214 case Smoothing_t::ORIGINAL:
215 smoothing_factor = 1.0 / n_samples;
217 case Smoothing_t::LAPLACE:
218 smoothing_factor = 1.0;
220 case Smoothing_t::CESTNIK:
221 smoothing_factor = 1 / numStates;
224 smoothing_factor = 0.0;
226 node.second->computeCPT(samples, features, smoothing_factor, weights);
// Launch one worker per node. std::ref passes the map entry by reference, not a copy.
230 for (
auto& node : nodes) {
232 threads.emplace_back(worker, std::ref(node), i++);
234 for (
auto& thread : threads) {
// Scores every sample of `samples`: one row per feature, one column per sample.
// Each column is predicted in its own std::thread via predict_sample(), and its
// scores are written, under a mutex, into row i of `result`. `result` is a float64
// tensor of shape samples.size(1) x classNumStates.
// Throws std::logic_error if the network is not fitted, and std::invalid_argument
// if samples.size(0) differs from features.size() - 1.
// The visible final statement returns result.argmax(1), the highest-scoring class
// index per sample, which is presumably the `proba == false` path.
// NOTE(review): not visible in this view: the fitted check condition, the mutex
// declaration, the semaphore acquire/release, the joins and the `proba` branch.
// NOTE(review): on Linux a thread name is limited to 16 bytes including the NUL.
// "PredictWorker-<i>" exceeds that for i >= 10, so pthread_setname_np fails with
// ERANGE, and its return value is ignored here.
239 torch::Tensor Network::predict_tensor(
const torch::Tensor& samples,
const bool proba)
242 throw std::logic_error(
"You must call fit() before calling predict()");
245 if (samples.size(0) != features.size() - 1) {
246 throw std::invalid_argument(
"(T) Sample size (" + std::to_string(samples.size(0)) +
247 ") does not match the number of features (" + std::to_string(features.size() - 1) +
")");
249 torch::Tensor result;
250 std::vector<std::thread> threads;
252 auto& semaphore = CountingSemaphore::getInstance();
253 result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
254 auto worker = [&](
const torch::Tensor& sample,
int i) {
255 std::string threadName =
"PredictWorker-" + std::to_string(i);
256#if defined(__linux__)
257 pthread_setname_np(pthread_self(), threadName.c_str());
259 pthread_setname_np(threadName.c_str());
261 auto psample = predict_sample(sample);
262 auto temp = torch::tensor(psample, torch::kFloat64);
// Serialise writes into the shared result tensor.
264 std::lock_guard<std::mutex> lock(mtx);
265 result.index_put_({ i,
"..." }, temp);
// One thread per sample (column). std::thread stores its own copy of `sample`.
269 for (
int i = 0; i < samples.size(1); ++i) {
271 const torch::Tensor sample = samples.index({
"...", i });
272 threads.emplace_back(worker, sample, i);
274 for (
auto& thread : threads) {
279 return result.argmax(1);
282 torch::Tensor Network::predict_proba(
const torch::Tensor& samples)
284 return predict_tensor(samples,
true);
288 torch::Tensor Network::predict(
const torch::Tensor& samples)
290 return predict_tensor(samples,
false);
// Predicts a class for every sample of `tsamples`. The layout is one inner vector
// per feature, each holding one value per sample.
// For each sample (row), the column values are gathered into `sample`. A worker
// thread then stores the index of the largest predict_sample() score into
// predictions[row].
// Throws std::logic_error if the network is not fitted, and std::invalid_argument
// if the number of feature vectors differs from features.size() - 1.
// NOTE(review): not visible in this view: the fitted check condition,
// `sample.clear()` before each row (without it `sample` keeps growing), the
// semaphore acquire/release, the joins and the return.
// NOTE(review): tsamples[0] is read unguarded. If features.size() == 1 an empty
// `tsamples` passes the size check and this is undefined behaviour.
// NOTE(review): on Linux "(V)PWorker-<row>" exceeds the 16-byte thread-name limit
// for row >= 10000, and the pthread_setname_np error is ignored.
295 std::vector<int> Network::predict(
const std::vector<std::vector<int>>& tsamples)
298 throw std::logic_error(
"You must call fit() before calling predict()");
301 if (tsamples.size() != features.size() - 1) {
302 throw std::invalid_argument(
"(V) Sample size (" + std::to_string(tsamples.size()) +
303 ") does not match the number of features (" + std::to_string(features.size() - 1) +
")");
305 std::vector<int> predictions(tsamples[0].size(), 0);
306 std::vector<int> sample;
307 std::vector<std::thread> threads;
308 auto& semaphore = CountingSemaphore::getInstance();
// Each worker writes only its own predictions[row] slot, so no lock is needed.
309 auto worker = [&](
const std::vector<int>& sample,
const int row,
int& prediction) {
310 std::string threadName =
"(V)PWorker-" + std::to_string(row);
311#if defined(__linux__)
312 pthread_setname_np(pthread_self(), threadName.c_str());
314 pthread_setname_np(threadName.c_str());
316 auto classProbabilities = predict_sample(sample);
317 auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
318 int predictedClass = distance(classProbabilities.begin(), maxElem);
319 prediction = predictedClass;
// Transpose feature-major input into one sample at a time.
322 for (
int row = 0; row < tsamples[0].size(); ++row) {
324 for (
int col = 0; col < tsamples.size(); ++col) {
325 sample.push_back(tsamples[col][row]);
// std::thread copies `sample`, so reusing the buffer for the next row is safe.
328 threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
330 for (
auto& thread : threads) {
// Returns the predict_sample() scores of every sample of `tsamples`. The layout is
// one inner vector per feature, each holding one value per sample. The result has
// one entry of classNumStates doubles per sample.
// Throws std::logic_error if the network is not fitted, and std::invalid_argument
// if the number of feature vectors differs from features.size() - 1.
// NOTE(review): not visible in this view: the fitted check condition,
// `sample.clear()` before each row, the semaphore acquire/release, the joins and
// the return.
// NOTE(review): tsamples[0] is read unguarded, which is undefined behaviour for an
// empty `tsamples` when features.size() == 1.
// The worker's `predictions` parameter shadows the outer vector of the same name.
337 std::vector<std::vector<double>> Network::predict_proba(
const std::vector<std::vector<int>>& tsamples)
340 throw std::logic_error(
"You must call fit() before calling predict_proba()");
343 if (tsamples.size() != features.size() - 1) {
344 throw std::invalid_argument(
"(V) Sample size (" + std::to_string(tsamples.size()) +
345 ") does not match the number of features (" + std::to_string(features.size() - 1) +
")");
347 std::vector<std::vector<double>> predictions(tsamples[0].size(), std::vector<double>(classNumStates, 0.0));
348 std::vector<int> sample;
349 std::vector<std::thread> threads;
350 auto& semaphore = CountingSemaphore::getInstance();
// Each worker overwrites only its own row of the result, so no lock is needed.
351 auto worker = [&](
const std::vector<int>& sample,
int row, std::vector<double>& predictions) {
352 std::string threadName =
"(V)PWorker-" + std::to_string(row);
353#if defined(__linux__)
354 pthread_setname_np(pthread_self(), threadName.c_str());
356 pthread_setname_np(threadName.c_str());
358 std::vector<double> classProbabilities = predict_sample(sample);
359 predictions = classProbabilities;
// Transpose feature-major input into one sample at a time.
362 for (
int row = 0; row < tsamples[0].size(); ++row) {
364 for (
int col = 0; col < tsamples.size(); ++col) {
365 sample.push_back(tsamples[col][row]);
368 threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
370 for (
auto& thread : threads) {
375 double Network::score(
const std::vector<std::vector<int>>& tsamples,
const std::vector<int>& labels)
377 std::vector<int> y_pred = predict(tsamples);
379 for (
int i = 0; i < y_pred.size(); ++i) {
380 if (y_pred[i] == labels[i]) {
384 return (
double)correct / y_pred.size();
387 std::vector<double> Network::predict_sample(
const std::vector<int>& sample)
389 std::map<std::string, int> evidence;
390 for (
int i = 0; i < sample.size(); ++i) {
391 evidence[features[i]] = sample[i];
393 return exactInference(evidence);
396 std::vector<double> Network::predict_sample(
const torch::Tensor& sample)
398 std::map<std::string, int> evidence;
399 for (
int i = 0; i < sample.size(0); ++i) {
400 evidence[features[i]] = sample[i].item<
int>();
402 return exactInference(evidence);
// Computes normalised scores for each class value, given feature evidence.
// For every class value i, the evidence is extended with className = i, and the
// product of getFactorValue() over all nodes is computed. The per-class results
// are then divided by their sum.
// The caller's `evidence` is copied into completeEvidence, so it is not modified
// even though it is taken by non-const reference.
// NOTE(review): storing `partial` into result[i] and the final return are not
// visible in this view.
// NOTE(review): if every class product is 0 (e.g. no smoothing), `sum` is 0 and
// every entry becomes NaN.
404 std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
406 std::vector<double> result(classNumStates, 0.0);
407 auto completeEvidence = std::map<std::string, int>(evidence);
408 for (
int i = 0; i < classNumStates; ++i) {
409 completeEvidence[getClassName()] = i;
410 double partial = 1.0;
411 for (
auto& node : getNodes()) {
412 partial *= node.second->getFactorValue(completeEvidence);
// Normalise so the class scores sum to 1.
417 double sum = std::accumulate(result.begin(), result.end(), 0.0);
418 transform(result.begin(), result.end(), result.begin(), [sum](
const double& value) { return value / sum; });
// Builds a text summary with one line per node, in the form
// "<name> -> <child1>, <child2>, ". The trailing ", " after the last child is kept.
// NOTE(review): the statement between the declaration of `result` and the loop,
// and the return, are not visible in this view.
421 std::vector<std::string> Network::show()
const
423 std::vector<std::string> result;
425 for (
auto& node : nodes) {
426 std::string line = node.first +
" -> ";
427 for (
auto child : node.second->getChildren()) {
428 line += child->getName() +
", ";
430 result.push_back(line);
// Produces a Graphviz DOT description of the network as a list of text chunks:
// - a "digraph BayesNet" header whose label includes `title`;
// - the chunks returned by each node's graph(className);
// - the closing "}".
// NOTE(review): `title` is inserted into an HTML-like label (label=<...>) without
// escaping, so characters such as '<' or '&' would break the output.
// NOTE(review): the return statement is not visible in this view; presumably it
// returns `output`.
434 std::vector<std::string> Network::graph(
const std::string& title)
const
436 auto output = std::vector<std::string>();
437 auto prefix =
"digraph BayesNet {\nlabel=<BayesNet ";
438 auto suffix =
">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
439 std::string header = prefix + title + suffix;
440 output.push_back(header);
441 for (
auto& node : nodes) {
442 auto result = node.second->graph(className);
443 output.insert(output.end(), result.begin(), result.end());
445 output.push_back(
"}\n");
// Lists every directed edge as a (parent name, child name) pair. Edges are
// collected by walking each node's children, in the iteration order of `nodes`.
// NOTE(review): the return statement is not visible in this view; presumably it
// returns `edges`.
448 std::vector<std::pair<std::string, std::string>> Network::getEdges()
const
450 auto edges = std::vector<std::pair<std::string, std::string>>();
451 for (
const auto& node : nodes) {
452 auto head = node.first;
453 for (
const auto& child : node.second->getChildren()) {
454 auto tail = child->getName();
455 edges.push_back({ head, tail });
460 int Network::getNumEdges()
const
462 return getEdges().size();
// Orders the network features, class excluded, so that every non-class parent
// appears before its children.
// Starting from `features` minus the class, each parent found after one of its
// children is moved to just before that child. `ending` presumably drives a repeat
// loop until a full pass makes no change.
// NOTE(review): the loop header, the `continue` for class parents, the
// `ending = false` update and the return are not visible in this view.
// NOTE(review): each pass uses linear find/remove inside nested loops (roughly cubic
// work). Kahn's algorithm would give a linear-time topological order.
464 std::vector<std::string> Network::topological_sort()
467 auto result = features;
468 result.erase(remove(result.begin(), result.end(), className), result.end());
469 bool ending{
false };
472 for (
auto feature : features) {
473 auto fathers = nodes[feature]->getParents();
474 for (
const auto& father : fathers) {
475 auto fatherName = father->getName();
// The class is not part of `result`, so it never needs reordering.
476 if (fatherName == className) {
480 auto it = find(result.begin(), result.end(), fatherName);
481 if (it != result.end()) {
482 auto it2 = find(result.begin(), result.end(), feature);
483 if (it2 != result.end()) {
// distance < 0: the parent currently sits after its child.
484 if (distance(it, it2) < 0) {
// The parent is erased from a position after it2. vector::erase only invalidates
// iterators at or after the erase point, so it2 is still valid for the insert.
486 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
487 result.insert(it2, fatherName);
// Writes a human-readable dump of every node's CPT. For each node it emits a header
// line "* <name>: (<numStates>) : <CPT sizes>" followed by the CPT tensor itself.
// NOTE(review): the rest of this function (presumably returning oss.str()) lies
// beyond this view.
497 std::string Network::dump_cpt()
const
499 std::stringstream oss;
500 for (
auto& node : nodes) {
501 oss <<
"* " << node.first <<
": (" << node.second->getNumStates() <<
") : " << node.second->getCPT().sizes() << std::endl;
502 oss << node.second->getCPT() << std::endl;