BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
KDB.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6#include "bayesnet/utils/bayesnetUtils.h"
7#include "KDB.h"
8
9namespace bayesnet {
10 KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
11 {
12 validHyperparameters = { "k", "theta" };
13
14 }
15 void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
16 {
17 auto hyperparameters = hyperparameters_;
18 if (hyperparameters.contains("k")) {
19 k = hyperparameters["k"];
20 hyperparameters.erase("k");
21 }
22 if (hyperparameters.contains("theta")) {
23 theta = hyperparameters["theta"];
24 hyperparameters.erase("theta");
25 }
26 Classifier::setHyperparameters(hyperparameters);
27 }
28 void KDB::buildModel(const torch::Tensor& weights)
29 {
30 /*
31 1. For each feature Xi, compute mutual information, I(X;C),
32 where C is the class.
33 2. Compute class conditional mutual information I(Xi;XjIC), f or each
34 pair of features Xi and Xj, where i#j.
35 3. Let the used variable list, S, be empty.
36 4. Let the DAG network being constructed, BN, begin with a single
37 class node, C.
38 5. Repeat until S includes all domain features
39 5.1. Select feature Xmax which is not in S and has the largest value
40 I(Xmax;C).
41 5.2. Add a node to BN representing Xmax.
42 5.3. Add an arc from C to Xmax in BN.
43 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
44 the highest value for I(Xmax;X,jC).
45 5.5. Add Xmax to S.
46 Compute the conditional probabilility infered by the structure of BN by
47 using counts from DB, and output BN.
48 */
49 // 1. For each feature Xi, compute mutual information, I(X;C),
50 // where C is the class.
51 addNodes();
52 const torch::Tensor& y = dataset.index({ -1, "..." });
53 std::vector<double> mi;
54 for (auto i = 0; i < features.size(); i++) {
55 torch::Tensor firstFeature = dataset.index({ i, "..." });
56 mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
57 }
58 // 2. Compute class conditional mutual information I(Xi;XjIC), f or each
59 auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
60 // 3. Let the used variable list, S, be empty.
61 std::vector<int> S;
62 // 4. Let the DAG network being constructed, BN, begin with a single
63 // class node, C.
64 // 5. Repeat until S includes all domain features
65 // 5.1. Select feature Xmax which is not in S and has the largest value
66 // I(Xmax;C).
67 auto order = argsort(mi);
68 for (auto idx : order) {
69 // 5.2. Add a node to BN representing Xmax.
70 // 5.3. Add an arc from C to Xmax in BN.
71 model.addEdge(className, features[idx]);
72 // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
73 // the highest value for I(Xmax;X,jC).
74 add_m_edges(idx, S, conditionalEdgeWeights);
75 // 5.5. Add Xmax to S.
76 S.push_back(idx);
77 }
78 }
79 void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
80 {
81 auto n_edges = std::min(k, static_cast<int>(S.size()));
82 auto cond_w = clone(weights);
83 bool exit_cond = k == 0;
84 int num = 0;
85 while (!exit_cond) {
86 auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
87 auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
88 if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
89 try {
90 model.addEdge(features[max_minfo], features[idx]);
91 num++;
92 }
93 catch (const std::invalid_argument& e) {
94 // Loops are not allowed
95 }
96 }
97 cond_w.index_put_({ idx, max_minfo }, -1);
98 auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
99 auto candidates = candidates_mask.nonzero();
100 exit_cond = num == n_edges || candidates.size(0) == 0;
101 }
102 }
103 std::vector<std::string> KDB::graph(const std::string& title) const
104 {
105 std::string header{ title };
106 if (title == "KDB") {
107 header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
108 }
109 return model.graph(header);
110 }
111}