BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
TAN.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include "TAN.h"
8
9namespace bayesnet {
10 TAN::TAN() : Classifier(Network())
11 {
12 validHyperparameters = { "parent" };
13 }
14
15 void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
16 {
17 auto hyperparameters = hyperparameters_;
18 if (hyperparameters.contains("parent")) {
19 parent = hyperparameters["parent"];
20 hyperparameters.erase("parent");
21 }
22 Classifier::setHyperparameters(hyperparameters);
23 }
24 void TAN::buildModel(const torch::Tensor& weights)
25 {
26 // 0. Add all nodes to the model
27 addNodes();
28 // 1. Compute mutual information between each feature and the class and set the root node
29 // as the highest mutual information with the class
30 auto mi = std::vector <std::pair<int, float >>();
31 torch::Tensor class_dataset = dataset.index({ -1, "..." });
32 for (int i = 0; i < static_cast<int>(features.size()); ++i) {
33 torch::Tensor feature_dataset = dataset.index({ i, "..." });
34 auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
35 mi.push_back({ i, mi_value });
36 }
37 sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
38 auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
39 if (root >= static_cast<int>(features.size())) {
40 throw std::invalid_argument("The parent node is not in the dataset");
41 }
42 // 2. Compute mutual information between each feature and the class
43 auto weights_matrix = metrics.conditionalEdge(weights);
44 // 3. Compute the maximum spanning tree
45 auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
46 // 4. Add edges from the maximum spanning tree to the model
47 for (auto i = 0; i < mst.size(); ++i) {
48 auto [from, to] = mst[i];
49 model.addEdge(features[from], features[to]);
50 }
51 // 5. Add edges from the class to all features
52 for (auto feature : features) {
53 model.addEdge(className, feature);
54 }
55 }
56 std::vector<std::string> TAN::graph(const std::string& title) const
57 {
58 return model.graph(title);
59 }
60}