BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
Boost.h
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
6
7#ifndef BOOST_H
8#define BOOST_H
9#include <string>
10#include <tuple>
11#include <vector>
12#include <nlohmann/json.hpp>
13#include <torch/torch.h>
14#include "Ensemble.h"
15#include "bayesnet/feature_selection/FeatureSelect.h"
16namespace bayesnet {
/// String identifiers for feature-selection algorithms.
///
/// Accessed as `SelectFeatures.CFS`, `SelectFeatures.FCBF` and
/// `SelectFeatures.IWSS`; each member holds its own name as text.
/// NOTE(review): presumably these are the values accepted for the
/// `select_features_algorithm` hyperparameter -- confirm against
/// `setHyperparameters()`.
///
/// The object is `const` at namespace scope, so it has internal linkage:
/// every translation unit that includes this header gets its own copy.
const struct {
    std::string CFS = "CFS";
    std::string FCBF = "FCBF";
    std::string IWSS = "IWSS";
} SelectFeatures;
/// String identifiers for the processing order of features.
///
/// Accessed as `Orders.ASC` ("asc"), `Orders.DESC` ("desc") and
/// `Orders.RAND` ("rand").
/// NOTE(review): presumably ascending / descending / random -- confirm
/// against the code that consumes the `order_algorithm` hyperparameter.
///
/// The object is `const` at namespace scope, so it has internal linkage:
/// every translation unit that includes this header gets its own copy.
const struct {
    std::string ASC = "asc";
    std::string DESC = "desc";
    std::string RAND = "rand";
} Orders;
27 class Boost : public Ensemble {
28 public:
29 explicit Boost(bool predict_voting = false);
30 virtual ~Boost() override = default;
31 void setHyperparameters(const nlohmann::json& hyperparameters_) override;
32 protected:
33 std::vector<int> featureSelection(torch::Tensor& weights_);
34 void buildModel(const torch::Tensor& weights) override;
35 std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights);
36 std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights);
37 void add_model(std::unique_ptr<Classifier> model, double significance);
38 void remove_last_model();
39 //
40 // Attributes
41 //
42 torch::Tensor X_train, y_train, X_test, y_test;
43 // Hyperparameters
44 bool bisection = true; // if true, use bisection stratety to add k models at once to the ensemble
45 int maxTolerance = 3;
46 std::string order_algorithm = Orders.DESC; // order to process the KBest features asc, desc, rand
47 bool convergence = true; //if true, stop when the model does not improve
48 bool convergence_best = false; // wether to keep the best accuracy to the moment or the last accuracy as prior accuracy
49 bool selectFeatures = false; // if true, use feature selection
50 std::string select_features_algorithm; // Selected feature selection algorithm
51 FeatureSelect* featureSelector = nullptr;
52 double threshold = -1;
53 bool block_update = false; // if true, use block update algorithm, only meaningful if bisection is true
54 bool alpha_block = false; // if true, the alpha is computed with the ensemble built so far and the new model
55 };
56}
57#endif