7#include "bayesnet/feature_selection/CFS.h"
8#include "bayesnet/feature_selection/FCBF.h"
9#include "bayesnet/feature_selection/IWSS.h"
13Boost::Boost(
bool predict_voting) : Ensemble(predict_voting) {
14 validHyperparameters = {
"alpha_block",
"order",
"convergence",
"convergence_best",
"bisection",
15 "threshold",
"maxTolerance",
"predict_voting",
"select_features",
"block_update"};
17void Boost::setHyperparameters(
const nlohmann::json &hyperparameters_) {
18 auto hyperparameters = hyperparameters_;
19 if (hyperparameters.contains(
"order")) {
20 std::vector<std::string> algos = {Orders.ASC, Orders.DESC, Orders.RAND};
21 order_algorithm = hyperparameters[
"order"];
22 if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
23 throw std::invalid_argument(
"Invalid order algorithm, valid values [" + Orders.ASC +
", " + Orders.DESC +
24 ", " + Orders.RAND +
"]");
26 hyperparameters.erase(
"order");
28 if (hyperparameters.contains(
"alpha_block")) {
29 alpha_block = hyperparameters[
"alpha_block"];
30 hyperparameters.erase(
"alpha_block");
32 if (hyperparameters.contains(
"convergence")) {
33 convergence = hyperparameters[
"convergence"];
34 hyperparameters.erase(
"convergence");
36 if (hyperparameters.contains(
"convergence_best")) {
37 convergence_best = hyperparameters[
"convergence_best"];
38 hyperparameters.erase(
"convergence_best");
40 if (hyperparameters.contains(
"bisection")) {
41 bisection = hyperparameters[
"bisection"];
42 hyperparameters.erase(
"bisection");
44 if (hyperparameters.contains(
"threshold")) {
45 threshold = hyperparameters[
"threshold"];
46 hyperparameters.erase(
"threshold");
48 if (hyperparameters.contains(
"maxTolerance")) {
49 maxTolerance = hyperparameters[
"maxTolerance"];
50 if (maxTolerance < 1 || maxTolerance > 6)
51 throw std::invalid_argument(
"Invalid maxTolerance value, must be greater in [1, 6]");
52 hyperparameters.erase(
"maxTolerance");
54 if (hyperparameters.contains(
"predict_voting")) {
55 predict_voting = hyperparameters[
"predict_voting"];
56 hyperparameters.erase(
"predict_voting");
58 if (hyperparameters.contains(
"select_features")) {
59 auto selectedAlgorithm = hyperparameters[
"select_features"];
60 std::vector<std::string> algos = {SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.FCBF};
61 selectFeatures =
true;
62 select_features_algorithm = selectedAlgorithm;
63 if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
64 throw std::invalid_argument(
"Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS +
", " +
65 SelectFeatures.CFS +
", " + SelectFeatures.FCBF +
"]");
67 hyperparameters.erase(
"select_features");
69 if (hyperparameters.contains(
"block_update")) {
70 block_update = hyperparameters[
"block_update"];
71 hyperparameters.erase(
"block_update");
73 if (block_update && alpha_block) {
74 throw std::invalid_argument(
"alpha_block and block_update cannot be true at the same time");
76 if (block_update && !bisection) {
77 throw std::invalid_argument(
"block_update needs bisection to be true");
79 Classifier::setHyperparameters(hyperparameters);
81void Boost::add_model(std::unique_ptr<Classifier> model,
double significance) {
82 models.push_back(std::move(model));
84 significanceModels.push_back(significance);
86void Boost::remove_last_model() {
88 significanceModels.pop_back();
// Resets the significance list and fills X_train / y_train / X_test / y_test
// from `dataset`. The `weights` argument is not used in the visible lines.
// NOTE(review): the digits at the start of each line look like original line
// numbers fused in by extraction, and they skip values (91 -> 94 -> 97 -> 100,
// 108 -> 111, 114 -> 117), so statements and closing braces appear to be
// missing from this extract -- confirm against the real file before editing.
91void Boost::buildModel(
const torch::Tensor &weights) {
94 significanceModels.clear();
// Index -1 along the first dimension of `dataset`, every position along the rest.
// Presumably the class labels, since the same row is used for y_train / y_test below.
97 auto y_ = dataset.index({-1,
"..."});
// presumably (number of folds, labels, seed) -- TODO confirm against the folding library
100 auto fold = folding::StratifiedKFold(5, y_, 271);
// Only fold 0 is used.
101 auto [train, test] = fold.getFold(0);
102 auto train_t = torch::tensor(train);
103 auto test_t = torch::tensor(test);
// Slice(0, size(0) - 1) keeps every row except the last one; the second index
// restricts the columns to the train (resp. test) indices of the fold.
105 X_train = dataset.index({torch::indexing::Slice(0, dataset.size(0) - 1), train_t});
106 y_train = dataset.index({-1, train_t});
107 X_test = dataset.index({torch::indexing::Slice(0, dataset.size(0) - 1), test_t});
108 y_test = dataset.index({-1, test_t});
111 auto n_classes = states.at(className).size();
113 buildDataset(y_train);
114 metrics = Metrics(dataset, features, className, n_classes);
// NOTE(review): X_train is assigned a second time here, now with every column
// of `dataset`, which overwrites the fold-based value set above. Together with
// the numbering gap this suggests the two assignments belong to different
// branches of a conditional that is not visible -- confirm.
117 X_train = dataset.index({torch::indexing::Slice(0, dataset.size(0) - 1),
"..."});
121std::vector<int> Boost::featureSelection(torch::Tensor &weights_) {
123 if (select_features_algorithm == SelectFeatures.CFS) {
124 featureSelector =
new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);
125 }
else if (select_features_algorithm == SelectFeatures.IWSS) {
126 if (threshold < 0 || threshold > 0.5) {
127 throw std::invalid_argument(
"Invalid threshold value for " + SelectFeatures.IWSS +
" [0, 0.5]");
130 new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
131 }
else if (select_features_algorithm == SelectFeatures.FCBF) {
132 if (threshold < 1e-7 || threshold > 1) {
133 throw std::invalid_argument(
"Invalid threshold value for " + SelectFeatures.FCBF +
" [1e-7, 1]");
136 new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
138 featureSelector->fit();
139 auto featuresUsed = featureSelector->getFeatures();
140 delete featureSelector;
// Computes the weighted error of `ypred` against `ytrain`, derives alpha_t from
// it, rescales and renormalises `weights`, and returns
// {weights, alpha_t, terminate}. The first tuple element is a reference to the
// `weights` argument.
// NOTE(review): the leading digits look like original line numbers fused in by
// extraction and they skip values (145 -> 147, 153 -> 159), so lines are
// missing from this extract. In particular alpha_t has no visible declaration,
// `terminate` is never set to true in the visible lines, and whatever followed
// `if (epsilon_t > 0.5) {` up to line 159 is not shown -- confirm against the
// real file before changing anything here.
143std::tuple<torch::Tensor &, double, bool> Boost::update_weights(torch::Tensor &ytrain, torch::Tensor &ypred,
144 torch::Tensor &weights) {
145 bool terminate =
false;
// Element-wise masks of the misclassified / correctly classified samples.
147 auto mask_wrong = ypred != ytrain;
148 auto mask_right = ypred == ytrain;
// epsilon_t is the sum of the weights of the misclassified samples.
149 auto masked_weights = weights * mask_wrong.to(weights.dtype());
150 double epsilon_t = masked_weights.sum().item<
double>();
153 if (epsilon_t > 0.5) {
159 double wt = (1 - epsilon_t) / epsilon_t;
// alpha_t is fixed to 1 when there are no errors; otherwise 0.5 * log(wt).
160 alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
// Both updates are additive (+=): a misclassified sample ends up with
// w * (1 + exp(alpha_t)), a correctly classified one with w * (1 + exp(-alpha_t)).
163 weights += mask_wrong.to(weights.dtype()) * exp(alpha_t) * weights;
165 weights += mask_right.to(weights.dtype()) * exp(-alpha_t) * weights;
// Renormalise so the weights sum to 1.
167 double totalWeights = torch::sum(weights).item<
double>();
168 weights = weights / totalWeights;
170 return {weights, alpha_t, terminate};
// Updates `weights` from a prediction made with only the last k models:
// the earlier models are moved aside, predict(X_train) is evaluated,
// update_weights() is applied, the models are put back, and the resulting
// alpha_t is written as the significance of the last k models.
// Returns {weights, alpha_t, terminate} as produced by update_weights().
// NOTE(review): the leading digits look like original line numbers fused in by
// extraction; they jump from 173 to 216 and skip further values below, so
// lines are missing from this extract. alpha_t and terminate have no visible
// declaration, and the `ytrain` parameter is not used in the visible lines
// (the call below passes y_train instead) -- confirm against the real file.
172std::tuple<torch::Tensor &, double, bool> Boost::update_weights_block(
int k, torch::Tensor &ytrain,
173 torch::Tensor &weights) {
216 std::unique_ptr<Classifier> model;
217 std::vector<std::unique_ptr<Classifier>> models_bak;
// Keep the current significances and model count so they can be restored below.
219 auto significance_bak = significanceModels;
220 auto n_models_bak = n_models;
// Every remaining model gets significance 1.0 for the prediction below.
222 significanceModels = std::vector<double>(k, 1.0);
// Move the first (n_models - k) models, in order, from `models` into models_bak.
225 for (
int i = 0; i < n_models - k; ++i) {
226 model = std::move(models[0]);
227 models.erase(models.begin());
228 models_bak.push_back(std::move(model));
// Exactly the last k models must remain.
230 assert(models.size() == k);
234 auto ypred = predict(X_train);
240 std::tie(weights, alpha_t, terminate) = update_weights(y_train, ypred, weights);
// When k equals the saved model count nothing was moved aside, so there is nothing to restore.
246 if (k != n_models_bak) {
248 int bak_size = models_bak.size();
// Take models from the back of models_bak and insert each at the front of
// `models`, which re-creates the original order.
249 for (
int i = 0; i < bak_size; ++i) {
250 model = std::move(models_bak[bak_size - 1 - i]);
251 models_bak.erase(models_bak.end() - 1);
252 models.insert(models.begin(), std::move(model));
256 significanceModels = significance_bak;
// The last k entries of the restored significance list are overwritten with alpha_t.
261 for (
int i = 0; i < k; ++i) {
262 significanceModels[n_models_bak - k + i] = alpha_t;
265 n_models = n_models_bak;
266 return {weights, alpha_t, terminate};