6#include "bayesnet/utils/bayesnetUtils.h"
10 KDB::KDB(
int k,
float theta) : Classifier(Network()), k(k), theta(theta)
12 validHyperparameters = {
"k",
"theta" };
15 void KDB::setHyperparameters(
const nlohmann::json& hyperparameters_)
17 auto hyperparameters = hyperparameters_;
18 if (hyperparameters.contains(
"k")) {
19 k = hyperparameters[
"k"];
20 hyperparameters.erase(
"k");
22 if (hyperparameters.contains(
"theta")) {
23 theta = hyperparameters[
"theta"];
24 hyperparameters.erase(
"theta");
26 Classifier::setHyperparameters(hyperparameters);
// NOTE(review): these lines appear to come from a line-numbered source
// browser. The leading integers look like original line numbers, and their
// gaps (29-51, 57-58, 60-66, 69-70, 72-73, 75-78) mean lines are missing from
// this block: the braces and the declaration of `S` used below are not
// visible. Restore the block from the original source before changing logic.
//
// Builds the KDB network structure from `dataset`.
//   weights - forwarded unchanged to metrics.mutualInformation() and
//             metrics.conditionalEdge(); presumably per-sample weights (confirm).
28 void KDB::buildModel(
const torch::Tensor& weights)
// y is the last slice of `dataset` along dimension 0, and feature i is slice i
// (assumes a variables-by-samples layout with the class last — TODO confirm).
52 const torch::Tensor& y = dataset.index({ -1,
"..." });
// mi[i] holds metrics.mutualInformation(feature i, y, weights).
53 std::vector<double> mi;
// NOTE(review): `auto i = 0` deduces int, while features.size() is presumably
// unsigned; this comparison mixes signed and unsigned types.
54 for (
auto i = 0; i < features.size(); i++) {
55 torch::Tensor firstFeature = dataset.index({ i,
"..." });
56 mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
// Conditional edge weights; add_m_edges() reads row `idx` of this tensor.
59 auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
// Visit features in the order produced by argsort(mi) (presumably decreasing
// mutual information with the class — confirm argsort's ordering).
67 auto order = argsort(mi);
68 for (
auto idx : order) {
// Every feature gets the class node as a parent...
71 model.addEdge(className, features[idx]);
// ...plus up to min(k, S.size()) further parents chosen from indices in S.
74 add_m_edges(idx, S, conditionalEdgeWeights);
// NOTE(review): these lines appear to come from a line-numbered source
// browser. The leading integers look like original line numbers, and their
// gaps (80, 84-85, 89, 91-92, 94-96, 101-102) mean lines are missing. Not
// visible here: the braces, the declaration and increment of `num`, and the
// header of the loop that presumably repeats lines 86-100. Restore the block
// from the original source before changing its logic.
//
// Adds parent edges features[j] -> features[idx], where each j is taken from
// row `idx` of `weights` (presumably highest weight first, given the argmax
// plus -1 marking below).
//   idx     - index of the feature that receives the new parent edges.
//   S       - candidate parent indices; an edge is added only if j is in S.
//   weights - conditional edge weight matrix; only a clone of it is modified.
79 void KDB::add_m_edges(
int idx, std::vector<int>& S, torch::Tensor& weights)
// Never aim for more edges than k, nor more than there are indices in S.
81 auto n_edges = std::min(k,
static_cast<int>(S.size()));
// Work on a copy so visited entries can be overwritten with -1 below
// without touching the caller's tensor.
82 auto cond_w = clone(weights);
// exit_cond starts true when k == 0 (presumably skipping the search entirely).
83 bool exit_cond = k == 0;
// Strongest remaining candidate in row idx.
86 auto max_minfo = argmax(cond_w.index({ idx,
"..." })).item<
int>();
87 auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
// A candidate becomes a parent only if it is in S and its weight exceeds theta.
88 if (belongs && cond_w.index({ idx, max_minfo }).item<
float>() > theta) {
90 model.addEdge(features[max_minfo], features[idx]);
// Presumably thrown by model.addEdge when it rejects the edge (e.g. it would
// create a cycle); the handler body is not visible here — confirm.
93 catch (
const std::invalid_argument& e) {
// Mark this candidate as visited so the next argmax moves on.
// NOTE(review): -1 only drops the entry from the candidate mask below when
// theta >= -1. With theta < -1, visited entries still count as candidates,
// so the stop condition can only be met via num == n_edges.
97 cond_w.index_put_({ idx, max_minfo }, -1);
98 auto candidates_mask = cond_w.index({ idx,
"..." }).gt(theta);
99 auto candidates = candidates_mask.nonzero();
// Stop once enough edges were added or no entry in row idx exceeds theta.
100 exit_cond = num == n_edges || candidates.size(0) == 0;
103 std::vector<std::string> KDB::graph(
const std::string& title)
const
105 std::string header{ title };
106 if (title ==
"KDB") {
107 header +=
" (k=" + std::to_string(k) +
", theta=" + std::to_string(theta) +
")";
109 return model.graph(header);