BayesNet 1.0.7.
Bayesian Network and basic classifiers Library.
Loading...
Searching...
No Matches
SPODE.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include "SPODE.h"
8
9namespace bayesnet {
10
11 SPODE::SPODE(int root) : Classifier(Network()), root(root)
12 {
13 validHyperparameters = { "parent" };
14 }
15
16 void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
17 {
18 auto hyperparameters = hyperparameters_;
19 if (hyperparameters.contains("parent")) {
20 root = hyperparameters["parent"];
21 hyperparameters.erase("parent");
22 }
23 Classifier::setHyperparameters(hyperparameters);
24 }
25 void SPODE::buildModel(const torch::Tensor& weights)
26 {
27 // 0. Add all nodes to the model
28 addNodes();
29 // 1. Add edges from the class node to all other nodes
30 // 2. Add edges from the root node to all other nodes
31 if (root >= static_cast<int>(features.size())) {
32 throw std::invalid_argument("The parent node is not in the dataset");
33 }
34 for (int i = 0; i < static_cast<int>(features.size()); ++i) {
35 model.addEdge(className, features[i]);
36 if (i != root) {
37 model.addEdge(features[root], features[i]);
38 }
39 }
40 }
41 std::vector<std::string> SPODE::graph(const std::string& name) const
42 {
43 return model.graph(name);
44 }
45
46}