BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
SPODELd.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include "SPODELd.h"
8
9namespace bayesnet {
10 SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className) {}
11 SPODELd& SPODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
12 {
13 checkInput(X_, y_);
14 Xf = X_;
15 y = y_;
16 return commonFit(features_, className_, states_, smoothing);
17 }
18
19 SPODELd& SPODELd::fit(torch::Tensor& dataset, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
20 {
21 if (!torch::is_floating_point(dataset)) {
22 throw std::runtime_error("Dataset must be a floating point tensor");
23 }
24 Xf = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }).clone();
25 y = dataset.index({ -1, "..." }).clone().to(torch::kInt32);
26 return commonFit(features_, className_, states_, smoothing);
27 }
28
29 SPODELd& SPODELd::commonFit(const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_, const Smoothing_t smoothing)
30 {
31 features = features_;
32 className = className_;
33 // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
34 states = fit_local_discretization(y);
35 // We have discretized the input data
36 // 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network
37 SPODE::fit(dataset, features, className, states, smoothing);
38 states = localDiscretizationProposal(states, model);
39 return *this;
40 }
41 torch::Tensor SPODELd::predict(torch::Tensor& X)
42 {
43 auto Xt = prepareX(X);
44 return SPODE::predict(Xt);
45 }
46 std::vector<std::string> SPODELd::graph(const std::string& name) const
47 {
48 return SPODE::graph(name);
49 }
50}