10 Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_) {}
// NOTE(review): orphaned fragment — a range-for over every (key, value) entry of
// `discretizers`. The enclosing definition and the loop body are not visible in this
// excerpt, so what is done with each entry cannot be determined from here.
13 for (
auto& [key, value] : discretizers) {
17 void Proposal::checkInput(
const torch::Tensor& X,
const torch::Tensor& y)
19 if (!torch::is_floating_point(X)) {
20 throw std::invalid_argument(
"X must be a floating point tensor");
22 if (torch::is_floating_point(y)) {
23 throw std::invalid_argument(
"y must be an integer tensor");
26 map<std::string, std::vector<int>> Proposal::localDiscretizationProposal(
const map<std::string, std::vector<int>>& oldStates, Network& model)
30 auto order = model.topological_sort();
31 auto& nodes = model.getNodes();
32 map<std::string, std::vector<int>> states = oldStates;
33 std::vector<int> indicesToReDiscretize;
35 for (
auto feature : order) {
36 auto nodeParents = nodes[feature]->getParents();
37 if (nodeParents.size() < 2)
continue;
39 int index = find(pFeatures.begin(), pFeatures.end(), feature) - pFeatures.begin();
40 indicesToReDiscretize.push_back(index);
41 std::vector<std::string> parents;
42 transform(nodeParents.begin(), nodeParents.end(), back_inserter(parents), [](
const auto& p) {
return p->getName(); });
44 parents.erase(remove(parents.begin(), parents.end(), pClassName), parents.end());
46 std::vector<int> indices;
47 indices.push_back(-1);
48 transform(parents.begin(), parents.end(), back_inserter(indices), [&](
const auto& p) {
return find(pFeatures.begin(), pFeatures.end(), p) - pFeatures.begin(); });
50 std::vector<std::string> yJoinParents(Xf.size(1));
51 for (
auto idx : indices) {
52 for (
int i = 0; i < Xf.size(1); ++i) {
53 yJoinParents[i] += to_string(pDataset.index({ idx, i }).item<
int>());
56 auto yxv = factorize(yJoinParents);
57 auto xvf_ptr = Xf.index({ index }).data_ptr<float>();
58 auto xvf = std::vector<mdlp::precision_t>(xvf_ptr, xvf_ptr + Xf.size(1));
59 discretizers[feature]->fit(xvf, yxv);
63 for (
auto index : indicesToReDiscretize) {
64 auto Xt_ptr = Xf.index({ index }).data_ptr<float>();
65 auto Xt = std::vector<float>(Xt_ptr, Xt_ptr + Xf.size(1));
66 pDataset.index_put_({ index,
"..." }, torch::tensor(discretizers[pFeatures[index]]->transform(Xt)));
67 auto xStates = std::vector<int>(discretizers[pFeatures[index]]->getCutPoints().size() + 1);
68 iota(xStates.begin(), xStates.end(), 0);
70 states[pFeatures[index]] = xStates;
72 const torch::Tensor weights = torch::full({ pDataset.size(1) }, 1.0 / pDataset.size(1), torch::kDouble);
73 model.fit(pDataset, weights, pFeatures, pClassName, states, Smoothing_t::ORIGINAL);
77 map<std::string, std::vector<int>> Proposal::fit_local_discretization(
const torch::Tensor& y)
82 map<std::string, std::vector<int>> states;
83 pDataset = torch::zeros({ n + 1, m }, torch::kInt32);
84 auto yv = std::vector<int>(y.data_ptr<
int>(), y.data_ptr<
int>() + y.size(0));
86 for (
auto i = 0; i < pFeatures.size(); ++i) {
87 auto* discretizer =
new mdlp::CPPFImdlp();
88 auto Xt_ptr = Xf.index({ i }).data_ptr<float>();
89 auto Xt = std::vector<float>(Xt_ptr, Xt_ptr + Xf.size(1));
90 discretizer->fit(Xt, yv);
91 pDataset.index_put_({ i,
"..." }, torch::tensor(discretizer->transform(Xt)));
92 auto xStates = std::vector<int>(discretizer->getCutPoints().size() + 1);
93 iota(xStates.begin(), xStates.end(), 0);
94 states[pFeatures[i]] = xStates;
95 discretizers[pFeatures[i]] = discretizer;
97 int n_classes = torch::max(y).item<
int>() + 1;
98 auto yStates = std::vector<int>(n_classes);
99 iota(yStates.begin(), yStates.end(), 0);
100 states[pClassName] = yStates;
101 pDataset.index_put_({ n,
"..." }, y);
104 torch::Tensor Proposal::prepareX(torch::Tensor& X)
106 auto Xtd = torch::zeros_like(X, torch::kInt32);
107 for (
int i = 0; i < X.size(0); ++i) {
108 auto Xt = std::vector<float>(X[i].data_ptr<float>(), X[i].data_ptr<float>() + X.size(1));
109 auto Xd = discretizers[pFeatures[i]]->transform(Xt);
110 Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32));
114 std::vector<int> Proposal::factorize(
const std::vector<std::string>& labels_t)
117 yy.reserve(labels_t.size());
118 std::map<std::string, int> labelMap;
120 for (
const std::string& label : labels_t) {
121 if (labelMap.find(label) == labelMap.end()) {
122 labelMap[label] = i++;
123 bool allDigits = std::all_of(label.begin(), label.end(), ::isdigit);
125 yy.push_back(labelMap[label]);