11 SPODE::SPODE(
int root) : Classifier(Network()), root(root)
13 validHyperparameters = {
"parent" };
16 void SPODE::setHyperparameters(
const nlohmann::json& hyperparameters_)
18 auto hyperparameters = hyperparameters_;
19 if (hyperparameters.contains(
"parent")) {
20 root = hyperparameters[
"parent"];
21 hyperparameters.erase(
"parent");
23 Classifier::setHyperparameters(hyperparameters);
25 void SPODE::buildModel(
const torch::Tensor& weights)
31 if (root >=
static_cast<int>(features.size())) {
32 throw std::invalid_argument(
"The parent node is not in the dataset");
34 for (
int i = 0; i < static_cast<int>(features.size()); ++i) {
35 model.addEdge(className, features[i]);
37 model.addEdge(features[root], features[i]);
41 std::vector<std::string> SPODE::graph(
const std::string& name)
const
43 return model.graph(name);