BayesNet 1.0.7.
Bayesian Network and basic classifiers library.
Loading...
Searching...
No Matches
Classifier.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef CLASSIFIER_H
8#define CLASSIFIER_H
9#include <torch/torch.h>
10#include "bayesnet/utils/BayesMetrics.h"
11#include "bayesnet/BaseClassifier.h"
12
13namespace bayesnet {
14 class Classifier : public BaseClassifier {
15 public:
16 Classifier(Network model);
17 virtual ~Classifier() = default;
18 Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
19 Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
20 Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing) override;
21 Classifier& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing) override;
22 void addNodes();
23 int getNumberOfNodes() const override;
24 int getNumberOfEdges() const override;
25 int getNumberOfStates() const override;
26 int getClassNumStates() const override;
27 torch::Tensor predict(torch::Tensor& X) override;
28 std::vector<int> predict(std::vector<std::vector<int>>& X) override;
29 torch::Tensor predict_proba(torch::Tensor& X) override;
30 std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
31 status_t getStatus() const override { return status; }
32 std::string getVersion() override { return { project_version.begin(), project_version.end() }; };
33 float score(torch::Tensor& X, torch::Tensor& y) override;
34 float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
35 std::vector<std::string> show() const override;
36 std::vector<std::string> topological_order() override;
37 std::vector<std::string> getNotes() const override { return notes; }
38 std::string dump_cpt() const override;
39 void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
40 protected:
41 bool fitted;
42 unsigned int m, n; // m: number of samples, n: number of features
43 Network model;
44 Metrics metrics;
45 std::vector<std::string> features;
46 std::string className;
47 std::map<std::string, std::vector<int>> states;
48 torch::Tensor dataset; // (n+1)xm tensor
49 void checkFitParameters();
50 virtual void buildModel(const torch::Tensor& weights) = 0;
51 void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override;
52 void buildDataset(torch::Tensor& y);
53 const std::string CLASSIFIER_NOT_FITTED = "Classifier has not been fitted";
54 private:
55 Classifier& build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing);
56 };
57}
58#endif
59
60
61
62
63