10 TAN::TAN() : Classifier(Network())
12 validHyperparameters = {
"parent" };
15 void TAN::setHyperparameters(
const nlohmann::json& hyperparameters_)
17 auto hyperparameters = hyperparameters_;
18 if (hyperparameters.contains(
"parent")) {
19 parent = hyperparameters[
"parent"];
20 hyperparameters.erase(
"parent");
22 Classifier::setHyperparameters(hyperparameters);
24 void TAN::buildModel(
const torch::Tensor& weights)
30 auto mi = std::vector <std::pair<int, float >>();
31 torch::Tensor class_dataset = dataset.index({ -1,
"..." });
32 for (
int i = 0; i < static_cast<int>(features.size()); ++i) {
33 torch::Tensor feature_dataset = dataset.index({ i,
"..." });
34 auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
35 mi.push_back({ i, mi_value });
37 sort(mi.begin(), mi.end(), [](
const auto& left,
const auto& right) {return left.second < right.second;});
38 auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
39 if (root >=
static_cast<int>(features.size())) {
40 throw std::invalid_argument(
"The parent node is not in the dataset");
43 auto weights_matrix = metrics.conditionalEdge(weights);
45 auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
47 for (
auto i = 0; i < mst.size(); ++i) {
48 auto [from, to] = mst[i];
49 model.addEdge(features[from], features[to]);
52 for (
auto feature : features) {
53 model.addEdge(className, feature);
56 std::vector<std::string> TAN::graph(
const std::string& title)
const
58 return model.graph(title);