// BayesNet 1.0.7.
// Bayesian Network and basic classifiers Library.
// XBA2DE.cc
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2025 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#include "XBA2DE.h"

#include <limits.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <folding.hpp>

#include "bayesnet/classifiers/XSP2DE.h"
#include "bayesnet/utils/TensorUtils.h"

13namespace bayesnet {
14
15XBA2DE::XBA2DE(bool predict_voting) : Boost(predict_voting) {}
16std::vector<int> XBA2DE::initializeModels(const Smoothing_t smoothing) {
17 torch::Tensor weights_ = torch::full({m}, 1.0 / m, torch::kFloat64);
18 std::vector<int> featuresSelected = featureSelection(weights_);
19 if (featuresSelected.size() < 2) {
20 notes.push_back("No features selected in initialization");
21 status = ERROR;
22 return std::vector<int>();
23 }
24 for (int i = 0; i < featuresSelected.size() - 1; i++) {
25 for (int j = i + 1; j < featuresSelected.size(); j++) {
26 std::unique_ptr<Classifier> model = std::make_unique<XSp2de>(featuresSelected[i], featuresSelected[j]);
27 model->fit(dataset, features, className, states, weights_, smoothing);
28 add_model(std::move(model), 1.0);
29 }
30 }
31 notes.push_back("Used features in initialization: " + std::to_string(featuresSelected.size()) + " of " +
32 std::to_string(features.size()) + " with " + select_features_algorithm);
33 return featuresSelected;
34}
// Train the boosted ensemble, AdaBoost-style (Ensemble Methods, Zhi-Hua Zhou,
// 2012): repeatedly rank feature pairs by weighted mutual information, fit an
// XSp2de model on the best remaining pair, and reweight the samples until the
// weight update signals a stop, the validation accuracy stalls past the
// tolerance, or the pair pool is exhausted.
// NOTE(review): the incoming `weights` parameter is not read here — sample
// weights are re-initialized to uniform below and evolved by update_weights()/
// update_weights_block().
// @param smoothing  smoothing strategy forwarded to every model's fit()
void XBA2DE::trainModel(const torch::Tensor &weights, const Smoothing_t smoothing) {
    //
    // Logging setup
    //
    // loguru::set_thread_name("XBA2DE");
    // loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
    // loguru::add_file("boostA2DE.log", loguru::Truncate, loguru::Verbosity_MAX);

    // Algorithm based on the adaboost algorithm for classification
    // as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
    // Cache the train split (and the validation split when convergence
    // checking is enabled) as plain matrices/vectors.
    X_train_ = TensorUtils::to_matrix(X_train);
    y_train_ = TensorUtils::to_vector<int>(y_train);
    if (convergence) {
        X_test_ = TensorUtils::to_matrix(X_test);
        y_test_ = TensorUtils::to_vector<int>(y_test);
    }
    fitted = true;
    double alpha_t = 0;
    // Start from uniform sample weights (m = number of training instances).
    torch::Tensor weights_ = torch::full({m}, 1.0 / m, torch::kFloat64);
    bool finished = false;
    std::vector<int> featuresUsed;
    if (selectFeatures) {
        // Seed the ensemble with models over the selected feature subset.
        featuresUsed = initializeModels(smoothing);
        if (featuresUsed.size() == 0) {
            // Initialization failed (status already set to ERROR there).
            return;
        }
        auto ypred = predict(X_train);
        std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
        // Update significance of the models
        for (int i = 0; i < n_models; ++i) {
            significanceModels[i] = alpha_t;
        }
        if (finished) {
            return;
        }
    }
    int numItemsPack = 0; // The counter of the models inserted in the current pack
    // Variables to control the accuracy finish condition
    double priorAccuracy = 0.0;
    double improvement = 1.0;
    double convergence_threshold = 1e-4;
    int tolerance = 0; // number of times the accuracy is lower than the convergence_threshold
    // Step 0: Set the finish condition
    // epsilon sub t > 0.5 => inverse the weights policy
    // validation error is not decreasing
    // run out of features
    bool ascending = order_algorithm == Orders.ASC;
    std::mt19937 g{173}; // fixed seed: Orders.RAND shuffling is deterministic across runs
    std::vector<std::pair<int, int>> pairSelection;
    while (!finished) {
        // Step 1: Build ranking with mutual information
        pairSelection = metrics.SelectKPairs(weights_, featuresUsed, ascending, 0); // Get all the pairs sorted
        if (order_algorithm == Orders.RAND) {
            std::shuffle(pairSelection.begin(), pairSelection.end(), g);
        }
        // With bisection the pack size doubles every time convergence stalls (2^tolerance).
        int k = bisection ? pow(2, tolerance) : 1;
        int counter = 0; // The model counter of the current pack
        // VLOG_SCOPE_F(1, "counter=%d k=%d featureSelection.size: %zu", counter, k, featureSelection.size());
        while (counter++ < k && pairSelection.size() > 0) {
            // Step 2: Fit a model on the best remaining feature pair.
            auto feature_pair = pairSelection[0];
            pairSelection.erase(pairSelection.begin());
            std::unique_ptr<Classifier> model;
            model = std::make_unique<XSp2de>(feature_pair.first, feature_pair.second);
            model->fit(dataset, features, className, states, weights_, smoothing);
            alpha_t = 0.0;
            if (!block_update) {
                auto ypred = model->predict(X_train);
                // Step 3.1: Compute the classifier amount of say
                std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
            }
            // Step 3.4: Store classifier and its accuracy to weigh its future vote
            numItemsPack++;
            models.push_back(std::move(model));
            significanceModels.push_back(alpha_t);
            n_models++;
            // VLOG_SCOPE_F(2, "numItemsPack: %d n_models: %d featuresUsed: %zu", numItemsPack, n_models,
            // featuresUsed.size());
        }
        if (block_update) {
            // Block mode: one weight update for the whole pack of k models.
            std::tie(weights_, alpha_t, finished) = update_weights_block(k, y_train, weights_);
        }
        if (convergence && !finished) {
            // Convergence check against the validation split.
            auto y_val_predict = predict(X_test);
            double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
            if (priorAccuracy == 0) {
                priorAccuracy = accuracy;
            } else {
                improvement = accuracy - priorAccuracy;
            }
            if (improvement < convergence_threshold) {
                // VLOG_SCOPE_F(3, " (improvement<threshold) tolerance: %d numItemsPack: %d improvement: %f prior: %f
                // current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
                tolerance++;
            } else {
                // VLOG_SCOPE_F(3, "* (improvement>=threshold) Reset. tolerance: %d numItemsPack: %d improvement: %f
                // prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
                tolerance = 0; // Reset the counter if the model performs better
                numItemsPack = 0;
            }
            if (convergence_best) {
                // Keep the best accuracy until now as the prior accuracy
                priorAccuracy = std::max(accuracy, priorAccuracy);
            } else {
                // Keep the last accuracy obtained as the prior accuracy
                priorAccuracy = accuracy;
            }
        }
        // VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size: %zu", tolerance, featuresUsed.size(),
        // features.size());
        finished = finished || tolerance > maxTolerance || pairSelection.size() == 0;
    }
    if (tolerance > maxTolerance) {
        // Convergence stalled: drop the last (unproductive) pack of models.
        if (numItemsPack < n_models) {
            notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
            // VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
            for (int i = 0; i < numItemsPack; ++i) {
                significanceModels.pop_back();
                models.pop_back();
                n_models--;
            }
        } else {
            notes.push_back("Convergence threshold reached & 0 models eliminated");
            // VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d",
            // n_models, numItemsPack);
        }
    }
    if (pairSelection.size() > 0) {
        notes.push_back("Pairs not used in train: " + std::to_string(pairSelection.size()));
        status = WARNING;
    }
    notes.push_back("Number of models: " + std::to_string(n_models));
}
167std::vector<std::string> XBA2DE::graph(const std::string &title) const { return Ensemble::graph(title); }
168} // namespace bayesnet