BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
BoostA2DE.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include <limits.h>
8#include <tuple>
9#include <folding.hpp>
10#include "BoostA2DE.h"
11
12namespace bayesnet {
13
14 BoostA2DE::BoostA2DE(bool predict_voting) : Boost(predict_voting)
15 {
16 }
17 std::vector<int> BoostA2DE::initializeModels(const Smoothing_t smoothing)
18 {
19 torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
20 std::vector<int> featuresSelected = featureSelection(weights_);
21 if (featuresSelected.size() < 2) {
22 notes.push_back("No features selected in initialization");
23 status = ERROR;
24 return std::vector<int>();
25 }
26 for (int i = 0; i < featuresSelected.size() - 1; i++) {
27 for (int j = i + 1; j < featuresSelected.size(); j++) {
28 auto parents = { featuresSelected[i], featuresSelected[j] };
29 std::unique_ptr<Classifier> model = std::make_unique<SPnDE>(parents);
30 model->fit(dataset, features, className, states, weights_, smoothing);
31 models.push_back(std::move(model));
32 significanceModels.push_back(1.0); // They will be updated later in trainModel
33 n_models++;
34 }
35 }
36 notes.push_back("Used features in initialization: " + std::to_string(featuresSelected.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
37 return featuresSelected;
38 }
39 void BoostA2DE::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing)
40 {
41 //
42 // Logging setup
43 //
44 // loguru::set_thread_name("BoostA2DE");
45 // loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
46 // loguru::add_file("boostA2DE.log", loguru::Truncate, loguru::Verbosity_MAX);
47
48 // Algorithm based on the adaboost algorithm for classification
49 // as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
50 fitted = true;
51 double alpha_t = 0;
52 torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
53 bool finished = false;
54 std::vector<int> featuresUsed;
55 if (selectFeatures) {
56 featuresUsed = initializeModels(smoothing);
57 if (featuresUsed.size() == 0) {
58 return;
59 }
60 auto ypred = predict(X_train);
61 std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
62 // Update significance of the models
63 for (int i = 0; i < n_models; ++i) {
64 significanceModels[i] = alpha_t;
65 }
66 if (finished) {
67 return;
68 }
69 }
70 int numItemsPack = 0; // The counter of the models inserted in the current pack
71 // Variables to control the accuracy finish condition
72 double priorAccuracy = 0.0;
73 double improvement = 1.0;
74 double convergence_threshold = 1e-4;
75 int tolerance = 0; // number of times the accuracy is lower than the convergence_threshold
76 // Step 0: Set the finish condition
77 // epsilon sub t > 0.5 => inverse the weights policy
78 // validation error is not decreasing
79 // run out of features
80 bool ascending = order_algorithm == Orders.ASC;
81 std::mt19937 g{ 173 };
82 std::vector<std::pair<int, int>> pairSelection;
83 while (!finished) {
84 // Step 1: Build ranking with mutual information
85 pairSelection = metrics.SelectKPairs(weights_, featuresUsed, ascending, 0); // Get all the pairs sorted
86 if (order_algorithm == Orders.RAND) {
87 std::shuffle(pairSelection.begin(), pairSelection.end(), g);
88 }
89 int k = bisection ? pow(2, tolerance) : 1;
90 int counter = 0; // The model counter of the current pack
91 // VLOG_SCOPE_F(1, "counter=%d k=%d featureSelection.size: %zu", counter, k, featureSelection.size());
92 while (counter++ < k && pairSelection.size() > 0) {
93 auto feature_pair = pairSelection[0];
94 pairSelection.erase(pairSelection.begin());
95 std::unique_ptr<Classifier> model;
96 model = std::make_unique<SPnDE>(std::vector<int>({ feature_pair.first, feature_pair.second }));
97 model->fit(dataset, features, className, states, weights_, smoothing);
98 alpha_t = 0.0;
99 if (!block_update) {
100 auto ypred = model->predict(X_train);
101 // Step 3.1: Compute the classifier amout of say
102 std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
103 }
104 // Step 3.4: Store classifier and its accuracy to weigh its future vote
105 numItemsPack++;
106 models.push_back(std::move(model));
107 significanceModels.push_back(alpha_t);
108 n_models++;
109 // VLOG_SCOPE_F(2, "numItemsPack: %d n_models: %d featuresUsed: %zu", numItemsPack, n_models, featuresUsed.size());
110 }
111 if (block_update) {
112 std::tie(weights_, alpha_t, finished) = update_weights_block(k, y_train, weights_);
113 }
114 if (convergence && !finished) {
115 auto y_val_predict = predict(X_test);
116 double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
117 if (priorAccuracy == 0) {
118 priorAccuracy = accuracy;
119 } else {
120 improvement = accuracy - priorAccuracy;
121 }
122 if (improvement < convergence_threshold) {
123 // VLOG_SCOPE_F(3, " (improvement<threshold) tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
124 tolerance++;
125 } else {
126 // VLOG_SCOPE_F(3, "* (improvement>=threshold) Reset. tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
127 tolerance = 0; // Reset the counter if the model performs better
128 numItemsPack = 0;
129 }
130 if (convergence_best) {
131 // Keep the best accuracy until now as the prior accuracy
132 priorAccuracy = std::max(accuracy, priorAccuracy);
133 } else {
134 // Keep the last accuray obtained as the prior accuracy
135 priorAccuracy = accuracy;
136 }
137 }
138 // VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size: %zu", tolerance, featuresUsed.size(), features.size());
139 finished = finished || tolerance > maxTolerance || pairSelection.size() == 0;
140 }
141 if (tolerance > maxTolerance) {
142 if (numItemsPack < n_models) {
143 notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
144 // VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
145 for (int i = 0; i < numItemsPack; ++i) {
146 significanceModels.pop_back();
147 models.pop_back();
148 n_models--;
149 }
150 } else {
151 notes.push_back("Convergence threshold reached & 0 models eliminated");
152 // VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack);
153 }
154 }
155 if (pairSelection.size() > 0) {
156 notes.push_back("Pairs not used in train: " + std::to_string(pairSelection.size()));
157 status = WARNING;
158 }
159 notes.push_back("Number of models: " + std::to_string(n_models));
160 }
161 std::vector<std::string> BoostA2DE::graph(const std::string& title) const
162 {
163 return Ensemble::graph(title);
164 }
165}