// Constructs an ensemble with no counted sub-models yet.
// The Classifier base is initialised with a default-constructed Network,
// n_models starts at 0, and predict_voting is stored; predict_proba() reads
// that flag to choose between predict_average_voting() and
// predict_average_proba().
10 Ensemble::Ensemble(
bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting)
// Message carried by the std::logic_error thrown from both predict_proba()
// overloads.
13 const std::string ENSEMBLE_NOT_FITTED =
"Ensemble has not been fitted";
// Fits every sub-model held in `models`.
// n_models is refreshed from models.size() first, so the other member
// functions that loop up to n_models see the current number of models.
// Each model is fitted on the ensemble's own dataset, features, className
// and states, with the smoothing passed by the caller.
// NOTE(review): `weights` does not appear to be forwarded to the sub-models
// here — confirm that this is intended.
14 void Ensemble::trainModel(
const torch::Tensor& weights,
const Smoothing_t smoothing)
16 n_models = models.size();
17 for (
auto i = 0; i < n_models; ++i) {
19 models[i]->fit(dataset, features, className, states, smoothing);
// Row-wise arg-max over nested vectors.
// For every row X[i] the position of its largest element is appended to
// y_pred, so y_pred ends up with one index per row of X.
22 std::vector<int> Ensemble::compute_arg_max(std::vector<std::vector<double>>& X)
24 std::vector<int> y_pred;
25 for (
auto i = 0; i < X.size(); ++i) {
// std::max_element returns the first of several equal maxima.
26 auto max = std::max_element(X[i].begin(), X[i].end());
// Turn the iterator into a zero-based index within the row.
27 y_pred.push_back(std::distance(X[i].begin(), max));
// Tensor overload of the arg-max helper: takes torch::argmax of X along
// dimension 1.
31 torch::Tensor Ensemble::compute_arg_max(torch::Tensor& X)
33 auto y_pred = torch::argmax(X, 1);
// Turns per-model votes into per-class weighted vote totals.
// `votes` is read as a 2-D int tensor indexed [row][model]; each entry is
// used as a class index. The result has votes.size(0) rows and one column
// per state of the class variable (states.at(className).size()).
// NOTE(review): entries of `votes` are used as indices into n_votes without
// a range check — presumably every sub-model predicts a value in
// [0, numClasses); confirm.
36 torch::Tensor Ensemble::voting(torch::Tensor& votes)
39 auto y_pred_ = votes.accessor<int, 2>();
// NOTE(review): y_pred_final appears to be unused — confirm and remove.
40 std::vector<int> y_pred_final;
41 int numClasses = states.at(className).size();
43 auto result = torch::zeros({ votes.size(0), numClasses }, torch::kFloat32);
// Total significance across all models.
44 auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
45 for (
int i = 0; i < votes.size(0); ++i) {
// n_votes[c] accumulates the significance of every model that voted for
// class c on row i.
48 std::vector<double> n_votes(numClasses, 0.0);
49 for (
int j = 0; j < n_models; ++j) {
50 n_votes[y_pred_[i][j]] += significanceModels.at(j);
52 result[i] = torch::tensor(n_votes);
// Class probabilities for nested-vector input.
// Reports an unfitted ensemble by throwing std::logic_error with the
// ENSEMBLE_NOT_FITTED message. The returned probabilities come from
// predict_average_voting(X) when predict_voting is true, and from
// predict_average_proba(X) otherwise.
58 std::vector<std::vector<double>> Ensemble::predict_proba(std::vector<std::vector<int>>& X)
61 throw std::logic_error(ENSEMBLE_NOT_FITTED);
63 return predict_voting ? predict_average_voting(X) : predict_average_proba(X);
// Class probabilities for tensor input.
// Reports an unfitted ensemble by throwing std::logic_error with the
// ENSEMBLE_NOT_FITTED message. The returned tensor comes from
// predict_average_voting(X) when predict_voting is true, and from
// predict_average_proba(X) otherwise.
65 torch::Tensor Ensemble::predict_proba(torch::Tensor& X)
68 throw std::logic_error(ENSEMBLE_NOT_FITTED);
70 return predict_voting ? predict_average_voting(X) : predict_average_proba(X);
// Predicted class indices for nested-vector input: computes predict_proba(X)
// and returns the arg-max of each probability row.
72 std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
74 auto res = predict_proba(X);
75 return compute_arg_max(res);
// Predicted class indices for tensor input: computes predict_proba(X) and
// returns its arg-max along dimension 1.
77 torch::Tensor Ensemble::predict(torch::Tensor& X)
79 auto res = predict_proba(X);
80 return compute_arg_max(res);
// Significance-weighted accumulation of the sub-models' class probabilities
// (tensor input).
// The accumulator has X.size(1) rows and one column per class state, as
// reported by the first model.
// NOTE(review): assumes X holds samples along dimension 1 and that `models`
// is non-empty (models[0] is dereferenced) — confirm against callers.
82 torch::Tensor Ensemble::predict_average_proba(torch::Tensor& X)
84 auto n_states = models[0]->getClassNumStates();
85 torch::Tensor y_pred = torch::zeros({ X.size(1), n_states }, torch::kFloat32);
86 for (
auto i = 0; i < n_models; ++i) {
87 auto ypredict = models[i]->predict_proba(X);
// Scale this model's probabilities by its significance before adding.
89 y_pred += ypredict * significanceModels[i];
// Total significance across all models.
91 auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
// Significance-weighted average of the sub-models' class probabilities
// (nested-vector input).
// y_pred has X[0].size() rows and one column per class state, as reported by
// the first model. Each model's probabilities are scaled by its significance
// and accumulated; every element is finally divided by the total
// significance.
// NOTE(review): X[0] and models[0] are accessed unconditionally — presumably
// callers guarantee a non-empty X and at least one model; confirm.
95 std::vector<std::vector<double>> Ensemble::predict_average_proba(std::vector<std::vector<int>>& X)
97 auto n_states = models[0]->getClassNumStates();
98 std::vector<std::vector<double>> y_pred(X[0].size(), std::vector<double>(n_states, 0.0));
99 for (
auto i = 0; i < n_models; ++i) {
100 auto ypredict = models[i]->predict_proba(X);
// The model's output must have the same shape as the accumulator.
101 assert(ypredict.size() == y_pred.size());
102 assert(ypredict[0].size() == y_pred[0].size());
104 for (
auto j = 0; j < ypredict.size(); ++j) {
// Row-wise: y_pred[j][k] += ypredict[j][k] * significance of model i.
105 std::transform(y_pred[j].begin(), y_pred[j].end(), ypredict[j].begin(), y_pred[j].begin(),
106 [significanceModels = significanceModels[i]](
double x,
double y) {
return x + y * significanceModels; });
// Total significance across all models, used as the divisor below.
109 auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
111 for (
auto j = 0; j < y_pred.size(); ++j) {
112 std::transform(y_pred[j].begin(), y_pred[j].end(), y_pred[j].begin(), [sum](
double x) {
return x / sum; });
// Nested-vector front end for the voting predictor: converts X to a tensor
// with bayesnet::vectorToTensor, calls the tensor overload, and converts its
// output back to nested vectors of double with tensorToVectorDouble.
116 std::vector<std::vector<double>> Ensemble::predict_average_voting(std::vector<std::vector<int>>& X)
118 torch::Tensor Xt = bayesnet::vectorToTensor(X,
false);
119 auto y_pred = predict_average_voting(Xt);
120 std::vector<std::vector<double>> result = tensorToVectorDouble(y_pred);
// Collects every sub-model's predictions and hands them to voting().
// y_pred is an int32 tensor with X.size(1) rows and one column per model;
// the value returned is whatever voting() produces from it.
// NOTE(review): assumes X holds samples along dimension 1 — confirm.
123 torch::Tensor Ensemble::predict_average_voting(torch::Tensor& X)
126 torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32);
127 for (
auto i = 0; i < n_models; ++i) {
128 auto ypredict = models[i]->predict(X);
// Store model i's predictions at index i of the last dimension.
129 y_pred.index_put_({
"...", i }, ypredict);
131 return voting(y_pred);
// Score for tensor input: predicts X, compares each prediction with the
// matching entry of y, and returns `correct` divided by the number of
// predictions.
// The quotient is computed in double and narrowed to the float return type.
// NOTE(review): presumably `correct` counts the matching positions, i.e. the
// result is the accuracy — confirm. No guard against an empty prediction
// tensor is visible here.
133 float Ensemble::score(torch::Tensor& X, torch::Tensor& y)
135 auto y_pred = predict(X);
137 for (
int i = 0; i < y_pred.size(0); ++i) {
138 if (y_pred[i].item<int>() == y[i].item<int>()) {
142 return (
double)correct / y_pred.size(0);
// Score for nested-vector input: predicts X, compares each prediction with
// the matching entry of y, and returns `correct` divided by the number of
// predictions.
// The quotient is computed in double and narrowed to the float return type.
// NOTE(review): presumably `correct` counts the matching positions, i.e. the
// result is the accuracy — confirm. y is indexed up to y_pred.size(), so it
// is assumed to be at least that long.
144 float Ensemble::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
146 auto y_pred = predict(X);
148 for (
int i = 0; i < y_pred.size(); ++i) {
149 if (y_pred[i] == y[i]) {
153 return (
double)correct / y_pred.size();
// Concatenates the show() output of every sub-model, in model order, into a
// single list of strings.
155 std::vector<std::string> Ensemble::show()
const
157 auto result = std::vector<std::string>();
158 for (
auto i = 0; i < n_models; ++i) {
159 auto res = models[i]->show();
160 result.insert(result.end(), res.begin(), res.end());
// Concatenates the graph() output of every sub-model, in model order.
// Each model receives its own title: the given title followed by "_" and the
// model's index.
164 std::vector<std::string> Ensemble::graph(
const std::string& title)
const
166 auto result = std::vector<std::string>();
167 for (
auto i = 0; i < n_models; ++i) {
168 auto res = models[i]->graph(title +
"_" + std::to_string(i));
169 result.insert(result.end(), res.begin(), res.end());
// Adds up getNumberOfNodes() over the first n_models sub-models.
173 int Ensemble::getNumberOfNodes()
const
176 for (
auto i = 0; i < n_models; ++i) {
177 nodes += models[i]->getNumberOfNodes();
// Adds up getNumberOfEdges() over the first n_models sub-models.
181 int Ensemble::getNumberOfEdges()
const
184 for (
auto i = 0; i < n_models; ++i) {
185 edges += models[i]->getNumberOfEdges();
// Adds up getNumberOfStates() over the first n_models sub-models.
189 int Ensemble::getNumberOfStates()
const
192 for (
auto i = 0; i < n_models; ++i) {
193 nstates += models[i]->getNumberOfStates();