BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
XSPODE.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef XSPODE_H
8#define XSPODE_H
9
10#include <vector>
11#include <torch/torch.h>
12#include "Classifier.h"
13#include "bayesnet/utils/CountingSemaphore.h"
14
15namespace bayesnet {
16
17 class XSpode : public Classifier {
18 public:
19 explicit XSpode(int spIndex);
20 std::vector<double> predict_proba(const std::vector<int>& instance) const;
21 std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
22 int predict(const std::vector<int>& instance) const;
23 void normalize(std::vector<double>& v) const;
24 std::string to_string() const;
25 int getNFeatures() const;
26 int getNumberOfNodes() const override;
27 int getNumberOfEdges() const override;
28 int getNumberOfStates() const override;
29 int getClassNumStates() const override;
30 std::vector<int>& getStates();
31 std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
32 void fitx(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing);
33 void setHyperparameters(const nlohmann::json& hyperparameters_) override;
34
35 //
36 // Classifier interface
37 //
38 torch::Tensor predict(torch::Tensor& X) override;
39 std::vector<int> predict(std::vector<std::vector<int>>& X) override;
40 torch::Tensor predict_proba(torch::Tensor& X) override;
41 float score(torch::Tensor& X, torch::Tensor& y) override;
42 float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
43 protected:
44 void buildModel(const torch::Tensor& weights) override;
45 void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
46 private:
47 void addSample(const std::vector<int>& instance, double weight);
48 void computeProbabilities();
49 int superParent_;
50 int nFeatures_;
51 int statesClass_;
52 std::vector<int> states_; // [states_feat0, ..., states_feat(N-1)] (class not included in this array)
53
54 // Class counts
55 std::vector<double> classCounts_; // [c], accumulative
56 std::vector<double> classPriors_; // [c], after normalization
57
58 // For p(x_sp = spVal | c)
59 std::vector<double> spFeatureCounts_; // [spVal * statesClass_ + c]
60 std::vector<double> spFeatureProbs_; // same shape, after normalization
61
62 // For p(x_child = childVal | x_sp = spVal, c)
63 // childCounts_ is big enough to hold all child features except sp:
64 // For each child f, we store childOffsets_[f] as the start index, then
65 // childVal, spVal, c => the data.
66 std::vector<double> childCounts_;
67 std::vector<double> childProbs_;
68 std::vector<int> childOffsets_;
69
70 double alpha_ = 1.0;
71 double initializer_; // for numerical stability
72 CountingSemaphore& semaphore_;
73 };
74}
75
76#endif // XSPODE_H