diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 40f1822cf2..b84ac69a0d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -268,6 +268,10 @@ class DiscreteBayesTreeClique { class DiscreteBayesTree { DiscreteBayesTree(); + void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree); + void addClique(const gtsam::DiscreteBayesTreeClique* clique); + void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique); + void print(string s = "DiscreteBayesTree\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -276,6 +280,12 @@ class DiscreteBayesTree { size_t size() const; bool empty() const; const DiscreteBayesTreeClique* operator[](size_t j) const; + const DiscreteBayesTreeClique* clique(size_t j) const; + size_t numCachedSeparatorMarginals() const; + + gtsam::DiscreteConditional marginalFactor(size_t key) const; + gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const; + gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const; double evaluate(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const; @@ -285,7 +295,6 @@ class DiscreteBayesTree { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - double operator()(const gtsam::DiscreteValues& values) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 49a360cbb6..d2033909c5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), using ADT = AlgebraicDecisionTree; +// Function to construct the Asia example +DiscreteBayesNet constructAsiaExample() { + DiscreteBayesNet asia; + + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version + + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + return asia; +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; @@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia; - - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + DiscreteBayesNet asia = constructAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 617eb7c9d5..e0402969dd 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) { shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); - for (auto clique : cliques) { - DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); - if (debug) { + if (debug) { + // print all shortcuts to root + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); + for (auto clique : cliques) { + DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); clique.second->conditional_->printSignature(); shortcut.print("shortcut:"); } @@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) { TEST(DiscreteBayesTree, MarginalFactors) { TestFixture self; + // Caclulate marginals with brute force enumeration. Vector marginals = Vector::Zero(15); for (size_t i = 0; i < self.assignments.size(); ++i) { DiscreteValues& x = self.assignments[i]; @@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) { TEST(DiscreteBayesTree, Dot) { TestFixture self; std::string actual = self.bayesTree->dot(); + // print actual: + if (debug) std::cout << actual << std::endl; EXPECT(actual == "digraph G{\n" "0[label=\"13, 11, 6, 7\"];\n" @@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) { EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9); } +/* ************************************************************************* */ +// Test creating a Bayes tree directly from cliques +TEST(DiscreteBayesTree, DirectFromCliques) { + // Create a BayesNet + DiscreteBayesNet bayesNet; + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + bayesNet.add(A % "1/3"); + bayesNet.add(B | A = "1/3 3/1"); + bayesNet.add(C | B = "3/1 3/1"); + + // Create cliques directly + auto clique2 = std::make_shared( + std::make_shared(C | B = "3/1 3/1")); + auto clique1 = std::make_shared( + std::make_shared(B | A = "1/3 3/1")); + auto clique0 = std::make_shared( + std::make_shared(A % "1/3")); + + // Create a BayesTree + DiscreteBayesTree bayesTree; + bayesTree.insertRoot(clique2); + bayesTree.addClique(clique1, clique2); + bayesTree.addClique(clique0, clique1); + + // Check that the BayesTree is correct + DiscreteValues values; + values[A.first] = 1; + values[B.first] = 1; + values[C.first] = 1; + + // Regression + double expected = .046875; + DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 0648a90f64..c65e2ddc26 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -28,6 +28,8 @@ #include #include #include +#include + namespace gtsam { /* ************************************************************************* */ @@ -335,112 +337,85 @@ namespace gtsam { } /* ************************************************************************* */ - template - typename BayesTree::sharedBayesNet - BayesTree::jointBayesNet(Key j1, Key j2, const Eliminate& function) const - { + // Find the lowest common ancestor of two cliques + template + static std::shared_ptr findLowestCommonAncestor( + const std::shared_ptr& C1, const std::shared_ptr& C2) { + // Collect all ancestors of C1 + std::unordered_set> ancestors; + for (auto p = C1; p; p = p->parent()) { + ancestors.insert(p); + } + + // Find the first common ancestor in C2's lineage + std::shared_ptr B; + for (auto p = C2; p; p = p->parent()) { + if (ancestors.count(p)) { + return p; // Return the common ancestor when found + } + } + + return nullptr; // Return nullptr if no common ancestor is found + } + + /* ************************************************************************* */ + // Given the clique P(F:S) and the ancestor clique B + // Return the Bayes tree P(S\B | S \cap B) + template + static auto factorInto( + const std::shared_ptr& p_F_S, const std::shared_ptr& B, + const typename CLIQUE::FactorGraphType::Eliminate& eliminate) { + gttic(Full_root_factoring); + + // Get the shortcut P(S|B) + auto p_S_B = p_F_S->shortcut(B, eliminate); + + // Compute S\B + KeyVector S_setminus_B = p_F_S->separator_setminus_B(B); + + // Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B) + auto [bayesTree, fg] = + typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal( + Ordering(S_setminus_B), eliminate); + return bayesTree; + } + + /* ************************************************************************* */ + template + typename BayesTree::sharedBayesNet BayesTree::jointBayesNet( + Key j1, Key j2, const Eliminate& eliminate) const { gttic(BayesTree_jointBayesNet); // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; - gttic(Lowest_common_ancestor); - // Find lowest common ancestor clique - sharedClique B; { - // Build two paths to the root - FastList path1, path2; { - sharedClique p = C1; - while(p) { - path1.push_front(p); - p = p->parent(); - } - } { - sharedClique p = C2; - while(p) { - path2.push_front(p); - p = p->parent(); - } - } - // Find the path intersection - typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); - if(*p1 == *p2) - B = *p1; - while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { - B = *p1; - ++p1; - ++p2; - } - } - gttoc(Lowest_common_ancestor); + // Find the lowest common ancestor clique + auto B = findLowestCommonAncestor(C1, C2); // Build joint on all involved variables FactorGraphType p_BC1C2; - if(B) - { + if (B) { // Compute marginal on lowest common ancestor clique - gttic(LCA_marginal); - FactorGraphType p_B = B->marginal2(function); - gttoc(LCA_marginal); - - // Compute shortcuts of the requested cliques given the lowest common ancestor - gttic(Clique_shortcuts); - BayesNetType p_C1_Bred = C1->shortcut(B, function); - BayesNetType p_C2_Bred = C2->shortcut(B, function); - gttoc(Clique_shortcuts); - - // Factor the shortcuts to be conditioned on the full root - // Get the set of variables to eliminate, which is C1\B. - gttic(Full_root_factoring); - std::shared_ptr p_C1_B; { - KeyVector C1_minus_B; { - KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C1_minus_B_set.erase(j); } - C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); - } - // Factor into C1\B | B. - p_C1_B = - FactorGraphType(p_C1_Bred) - .eliminatePartialMultifrontal(Ordering(C1_minus_B), function) - .first; - } - std::shared_ptr p_C2_B; { - KeyVector C2_minus_B; { - KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C2_minus_B_set.erase(j); } - C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); - } - // Factor into C2\B | B. - p_C2_B = - FactorGraphType(p_C2_Bred) - .eliminatePartialMultifrontal(Ordering(C2_minus_B), function) - .first; - } - gttoc(Full_root_factoring); + FactorGraphType p_B = B->marginal2(eliminate); + + // Factor the shortcuts to be conditioned on lowest common ancestor + auto p_C1_B = factorInto(C1, B, eliminate); + auto p_C2_B = factorInto(C2, B, eliminate); - gttic(Variable_joint); p_BC1C2.push_back(p_B); p_BC1C2.push_back(*p_C1_B); p_BC1C2.push_back(*p_C2_B); - if(C1 != B) - p_BC1C2.push_back(C1->conditional()); - if(C2 != B) - p_BC1C2.push_back(C2->conditional()); - gttoc(Variable_joint); - } - else - { - // The nodes have no common ancestor, they're in different trees, so they're joint is just the - // product of their marginals. - gttic(Disjoint_marginals); - p_BC1C2.push_back(C1->marginal2(function)); - p_BC1C2.push_back(C2->marginal2(function)); - gttoc(Disjoint_marginals); + if (C1 != B) p_BC1C2.push_back(C1->conditional()); + if (C2 != B) p_BC1C2.push_back(C2->conditional()); + } else { + // The nodes have no common ancestor, they're in different trees, so + // they're joint is just the product of their marginals. + p_BC1C2.push_back(C1->marginal2(eliminate)); + p_BC1C2.push_back(C2->marginal2(eliminate)); } // now, marginalize out everything that is not variable j1 or j2 - return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function); + return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 03f79c8cf1..4a2ae7560f 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -119,13 +119,14 @@ namespace gtsam { /** Assignment operator */ This& operator=(const This& other); + public: + /// @name Testable /// @{ /** check equality */ bool equals(const This& other, double tol = 1e-9) const; - public: /** print */ void print(const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; @@ -185,18 +186,19 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /// @name Graph Display - /// @{ + /// @} + /// @name Graph Display + /// @{ - /// Output to graphviz format, stream version. - void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /// Output to graphviz format string. - std::string dot( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /// output to file with graphviz format. - void saveGraph(const std::string& filename, + /// output to file with graphviz format. + void saveGraph(const std::string& filename, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; /// @} diff --git a/gtsam/inference/BayesTreeCliqueBase-inst.h b/gtsam/inference/BayesTreeCliqueBase-inst.h index a91fa4f78b..9e687be6b6 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inst.h +++ b/gtsam/inference/BayesTreeCliqueBase-inst.h @@ -104,14 +104,16 @@ namespace gtsam { } /* ************************************************************************* */ - // The shortcut density is a conditional P(S|R) of the separator of this - // clique on the root. We can compute it recursively from the parent shortcut - // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p - /* ************************************************************************* */ - template + // The shortcut density is a conditional P(S|B) of the separator of this + // clique on the root or common ancestor B. We can compute it recursively from + // the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the + // frontal nodes in p + /* ************************************************************************* + */ + template typename BayesTreeCliqueBase::BayesNetType - BayesTreeCliqueBase::shortcut(const derived_ptr& B, Eliminate function) const - { + BayesTreeCliqueBase::shortcut( + const derived_ptr& B, Eliminate function) const { gttic(BayesTreeCliqueBase_shortcut); // We only calculate the shortcut when this clique is not B // and when the S\B is not empty @@ -120,12 +122,10 @@ namespace gtsam { { // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph derived_ptr parent(parent_.lock()); - gttoc(BayesTreeCliqueBase_shortcut); FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) - gttic(BayesTreeCliqueBase_shortcut); p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp) - // Determine the variables we want to keepSet, S union B + // Determine the variables we want to keep, S union B KeyVector keep = shortcut_indices(B, p_Cp_B); // Marginalize out everything except S union B @@ -139,8 +139,9 @@ namespace gtsam { } /* *********************************************************************** */ - // separator marginal, uses separator marginal of parent recursively - // P(C) = P(F|S) P(S) + // Separator marginal, uses separator marginal of parent recursively + // Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) + // if P(Sp) is not cached, it will call separatorMarginal on the parent /* *********************************************************************** */ template typename BayesTreeCliqueBase::FactorGraphType @@ -150,30 +151,22 @@ namespace gtsam { gttic(BayesTreeCliqueBase_separatorMarginal); // Check if the Separator marginal was already calculated if (!cachedSeparatorMarginal_) { - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); - // If this is the root, there is no separator if (parent_.expired() /*(if we're the root)*/) { // we are root, return empty FactorGraphType empty; cachedSeparatorMarginal_ = empty; } else { - // Flatten recursion in timing outline - gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); - gttoc(BayesTreeCliqueBase_separatorMarginal); - // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // initialize P(Cp) with the parent separator marginal derived_ptr parent(parent_.lock()); - FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp) - - gttic(BayesTreeCliqueBase_separatorMarginal); - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); + FactorGraphType p_Cp( + parent->separatorMarginal(function)); // recursive P(Sp) // now add the parent conditional p_Cp.push_back(parent->conditional_); // P(Fp|Sp) - // The variables we want to keepSet are exactly the ones in S + // The variables we want to keep are exactly the ones in S KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents()); auto separatorMarginal = diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 0ccb04e908..c674fb13a5 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -190,11 +190,11 @@ namespace gtsam { friend class BayesTree; - protected: - /// Calculate set \f$ S \setminus B \f$ for shortcut calculations KeyVector separator_setminus_B(const derived_ptr& B) const; + protected: + /** Determine variable indices to keep in recursive separator shortcut calculation The factor * graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables * not in S union B. */ diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index e08491faba..e8943fc803 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -156,5 +156,36 @@ def test_discrete_bayes_tree_lookup(self): values[X(3)] = 2 self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10... + def test_direct_from_cliques(self): + """Test creating a Bayes tree directly from cliques.""" + # Create a BayesNet + bayesNet = DiscreteBayesNet() + A, B, C = (0, 2), (1, 2), (2, 2) + bayesNet.add(A, "1/3") + bayesNet.add(B, [A], "1/3 3/1") + bayesNet.add(C, [B], "3/1 3/1") + + # Create cliques directly + clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1")) + clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1")) + clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3")) + + # Create a BayesTree + bayesTree = gtsam.DiscreteBayesTree() + bayesTree.insertRoot(clique2) + bayesTree.addClique(clique1, clique2) + bayesTree.addClique(clique0, clique1) + + # Check that the BayesTree is correct + values = DiscreteValues() + values[0] = 1 + values[1] = 1 + values[2] = 1 + + # regression + expected = .046875 + self.assertAlmostEqual(expected, bayesNet.evaluate(values)) + + if __name__ == "__main__": unittest.main()