// Node constructor taking the node's name.
// NOTE(review): the initializer list and constructor body are not visible in this view.
11 Node::Node(
const std::string& name)
// Resets the conditional probability table to a default-constructed (undefined) tensor.
// NOTE(review): which member function this statement belongs to is not visible here.
19 cpTable = torch::Tensor();
// Const accessor for the node's name. Elsewhere in this file the name is the key used
// to look up the node's feature row (computeCPT) and its observed state (getFactorValue).
// NOTE(review): body not visible here; presumably returns the `name` member.
23 std::string Node::getName()
const
27 void Node::addParent(Node* parent)
29 parents.push_back(parent);
31 void Node::removeParent(Node* parent)
33 parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
35 void Node::removeChild(Node* child)
37 children.erase(std::remove(children.begin(), children.end(), child), children.end());
39 void Node::addChild(Node* child)
41 children.push_back(child);
// Mutable accessor for this node's parent list (returned by non-const reference, so
// callers can modify it directly).
// NOTE(review): body not visible here; presumably returns the `parents` member.
43 std::vector<Node*>& Node::getParents()
// Mutable accessor for this node's child list (returned by non-const reference, so
// callers can modify it directly).
// NOTE(review): body not visible here; presumably returns the `children` member.
47 std::vector<Node*>& Node::getChildren()
// Const accessor for the number of states of this node's variable; computeCPT uses
// each parent's value as the size of that parent's CPT dimension.
// NOTE(review): body not visible here; presumably returns the `numStates` member.
51 int Node::getNumStates()
const
55 void Node::setNumStates(
int numStates)
57 this->numStates = numStates;
// Mutable accessor for the conditional probability table (the `cpTable` tensor that
// computeCPT fills).
// NOTE(review): the body spans several lines that are not visible here and may perform
// checks before returning — confirm.
59 torch::Tensor& Node::getCPT()
70 unsigned Node::minFill()
72 std::unordered_set<std::string> neighbors;
73 for (
auto child : children) {
74 neighbors.emplace(child->getName());
76 for (
auto parent : parents) {
77 neighbors.emplace(parent->getName());
79 auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
80 return combinations(source).size();
82 std::vector<std::pair<std::string, std::string>> Node::combinations(
const std::vector<std::string>& source)
84 std::vector<std::pair<std::string, std::string>> result;
85 for (
int i = 0; i < source.size(); ++i) {
86 std::string temp = source[i];
87 for (
int j = i + 1; j < source.size(); ++j) {
88 result.push_back({ temp, source[j] });
// Fits this node's conditional probability table (CPT) from data.
// 1. Every cell starts at `smoothing` (an additive pseudo-count).
// 2. Each sample adds its weight to the cell addressed by this node's value,
//    followed by its parents' values in `parents` order.
// 3. The table is normalised over dimension 0 (this node's states), so that for
//    every parent configuration the entries over this node's states sum to 1.
//
// Parameters:
//   dataset   - sample n is read as dataset.index({"...", n}); its entries are then
//               indexed by feature position. Presumably shaped (n_features, n_samples)
//               — TODO confirm.
//   features  - feature names; a name's position in this vector selects its entry
//               within each sample.
//   smoothing - initial value of every CPT cell.
//   weights   - per-sample weights, indexed by sample number.
93 void Node::computeCPT(
const torch::Tensor& dataset,
const std::vector<std::string>& features,
const double smoothing,
const torch::Tensor& weights)
// CPT shape: [numStates, numStates of parents[0], numStates of parents[1], ...].
// NOTE(review): `dimensions` is not declared in the visible lines; if it is a member
// that is not cleared before this point, repeated calls would keep appending — confirm.
96 dimensions.reserve(parents.size() + 1);
98 dimensions.push_back(numStates);
99 for (
const auto& parent : parents) {
100 dimensions.push_back(parent->getNumStates());
// Double-precision table with every cell initialised to the smoothing pseudo-count.
104 cpTable = torch::full(dimensions, smoothing, torch::kDouble);
// Map each feature name to its position; if a name repeats, the last position wins.
106 std::unordered_map<std::string, int> featureIndexMap;
107 for (
size_t i = 0; i < features.size(); ++i) {
108 featureIndexMap[features[i]] = i;
// NOTE(review): operator[] inserts 0 for a name absent from `features`, which would
// silently read the first feature; confirm the name is validated on lines not shown here.
112 int name_index = featureIndexMap[name];
// Feature positions of the parents, in the same order as the CPT's parent dimensions.
114 std::vector<int> parent_indices;
115 parent_indices.reserve(parents.size());
116 for (
const auto& parent : parents) {
117 parent_indices.push_back(featureIndexMap[parent->getName()]);
// Accumulate each sample's weight into its CPT cell.
// NOTE(review): in the visible lines `coordinates` is never cleared between samples,
// so the index list would grow by (1 + parents.size()) entries per sample; confirm a
// clear() happens at the top of the loop body.
119 c10::List<c10::optional<at::Tensor>> coordinates;
120 for (
int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
122 auto sample = dataset.index({
"...", n_sample });
123 coordinates.push_back(sample[name_index]);
124 for (
size_t i = 0; i < parent_indices.size(); ++i) {
125 coordinates.push_back(sample[parent_indices[i]]);
// accumulate=true: weights of samples that address the same cell are summed.
128 cpTable.index_put_({ coordinates }, weights.index({ n_sample }),
true);
// Normalise by the sum over dimension 0 (keepdim=true so it broadcasts back).
131 cpTable /= cpTable.sum(0,
true);
133 double Node::getFactorValue(std::map<std::string, int>& evidence)
135 c10::List<c10::optional<at::Tensor>> coordinates;
137 coordinates.push_back(at::tensor(evidence[name]));
138 transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](
const auto& parent) { return at::tensor(evidence[parent->getName()]); });
139 return cpTable.index({ coordinates }).item<double>();
141 std::vector<std::string> Node::graph(
const std::string& className)
143 auto output = std::vector<std::string>();
144 auto suffix = name == className ?
", fontcolor=red, fillcolor=lightblue, style=filled " :
"";
145 output.push_back(
"\"" + name +
"\" [shape=circle" + suffix +
"] \n");
146 transform(children.begin(), children.end(), back_inserter(output), [
this](
const auto& child) {
return "\"" + name +
"\" -> \"" + child->getName() +
"\""; });