Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor jointBayesNet #1991

Merged
merged 8 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ class DiscreteBayesTreeClique {

class DiscreteBayesTree {
DiscreteBayesTree();
void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree);
void addClique(const gtsam::DiscreteBayesTreeClique* clique);
void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique);

void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand All @@ -276,6 +280,12 @@ class DiscreteBayesTree {
size_t size() const;
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;
const DiscreteBayesTreeClique* clique(size_t j) const;
size_t numCachedSeparatorMarginals() const;

gtsam::DiscreteConditional marginalFactor(size_t key) const;
gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const;
gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const;

double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
Expand All @@ -285,7 +295,6 @@ class DiscreteBayesTree {
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
33 changes: 20 additions & 13 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),

using ADT = AlgebraicDecisionTree<Key>;

// Function to construct the Asia example
DiscreteBayesNet constructAsiaExample() {
DiscreteBayesNet asia;

asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version

asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");

asia.add((Either | Tuberculosis, LungCancer) = "F T T T");

asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");

return asia;
}

/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
Expand Down Expand Up @@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) {

/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia;

asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version

asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");

asia.add((Either | Tuberculosis, LungCancer) = "F T T T");

asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
DiscreteBayesNet asia = constructAsiaExample();

// Convert to factor graph
DiscreteFactorGraph fg(asia);
Expand Down
48 changes: 43 additions & 5 deletions gtsam/discrete/tests/testDiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);

// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
if (debug) {
// print all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
clique.second->conditional_->printSignature();
shortcut.print("shortcut:");
}
Expand All @@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;

// Caclulate marginals with brute force enumeration.
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
Expand Down Expand Up @@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
TEST(DiscreteBayesTree, Dot) {
TestFixture self;
std::string actual = self.bayesTree->dot();
// print actual:
if (debug) std::cout << actual << std::endl;
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
Expand Down Expand Up @@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) {
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}

/* ************************************************************************* */
// Test creating a Bayes tree directly from cliques
TEST(DiscreteBayesTree, DirectFromCliques) {
// Create a BayesNet
DiscreteBayesNet bayesNet;
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
bayesNet.add(A % "1/3");
bayesNet.add(B | A = "1/3 3/1");
bayesNet.add(C | B = "3/1 3/1");

// Create cliques directly
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(A % "1/3"));

// Create a BayesTree
DiscreteBayesTree bayesTree;
bayesTree.insertRoot(clique2);
bayesTree.addClique(clique1, clique2);
bayesTree.addClique(clique0, clique1);

// Check that the BayesTree is correct
DiscreteValues values;
values[A.first] = 1;
values[B.first] = 1;
values[C.first] = 1;

// Regression
double expected = .046875;
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
155 changes: 65 additions & 90 deletions gtsam/inference/BayesTree-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <fstream>
#include <queue>
#include <cassert>
#include <unordered_set>

namespace gtsam {

/* ************************************************************************* */
Expand Down Expand Up @@ -335,112 +337,85 @@ namespace gtsam {
}

/* ************************************************************************* */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
{
// Find the lowest common ancestor of two cliques
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
// Collect all ancestors of C1
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
for (auto p = C1; p; p = p->parent()) {
ancestors.insert(p);
}

// Find the first common ancestor in C2's lineage
std::shared_ptr<CLIQUE> B;
for (auto p = C2; p; p = p->parent()) {
if (ancestors.count(p)) {
return p; // Return the common ancestor when found
}
}

return nullptr; // Return nullptr if no common ancestor is found
}

/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
dellaert marked this conversation as resolved.
Show resolved Hide resolved
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
gttic(Full_root_factoring);

// Get the shortcut P(S|B)
auto p_S_B = p_F_S->shortcut(B, eliminate);

// Compute S\B
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);

// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
auto [bayesTree, fg] =
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
Ordering(S_setminus_B), eliminate);
return bayesTree;
}

/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
Key j1, Key j2, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet);
// get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];

gttic(Lowest_common_ancestor);
// Find lowest common ancestor clique
sharedClique B; {
// Build two paths to the root
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Find the lowest common ancestor clique
auto B = findLowestCommonAncestor(C1, C2);

// Build joint on all involved variables
FactorGraphType p_BC1C2;

if(B)
{
if (B) {
// Compute marginal on lowest common ancestor clique
gttic(LCA_marginal);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);

// Compute shortcuts of the requested cliques given the lowest common ancestor
gttic(Clique_shortcuts);
BayesNetType p_C1_Bred = C1->shortcut(B, function);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);

// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
gttic(Full_root_factoring);
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
p_C1_B =
FactorGraphType(p_C1_Bred)
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
.first;
}
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
p_C2_B =
FactorGraphType(p_C2_Bred)
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
.first;
}
gttoc(Full_root_factoring);
FactorGraphType p_B = B->marginal2(eliminate);

// Factor the shortcuts to be conditioned on lowest common ancestor
auto p_C1_B = factorInto(C1, B, eliminate);
auto p_C2_B = factorInto(C2, B, eliminate);

gttic(Variable_joint);
p_BC1C2.push_back(p_B);
p_BC1C2.push_back(*p_C1_B);
p_BC1C2.push_back(*p_C2_B);
if(C1 != B)
p_BC1C2.push_back(C1->conditional());
if(C2 != B)
p_BC1C2.push_back(C2->conditional());
gttoc(Variable_joint);
}
else
{
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2.push_back(C1->marginal2(function));
p_BC1C2.push_back(C2->marginal2(function));
gttoc(Disjoint_marginals);
if (C1 != B) p_BC1C2.push_back(C1->conditional());
if (C2 != B) p_BC1C2.push_back(C2->conditional());
} else {
// The nodes have no common ancestor, they're in different trees, so
// they're joint is just the product of their marginals.
p_BC1C2.push_back(C1->marginal2(eliminate));
p_BC1C2.push_back(C2->marginal2(eliminate));
}

// now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
}

/* ************************************************************************* */
Expand Down
22 changes: 12 additions & 10 deletions gtsam/inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ namespace gtsam {
/** Assignment operator */
This& operator=(const This& other);

public:

/// @name Testable
/// @{

/** check equality */
bool equals(const This& other, double tol = 1e-9) const;

public:
/** print */
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
Expand Down Expand Up @@ -185,18 +186,19 @@ namespace gtsam {
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/// @name Graph Display
/// @{
/// @}
/// @name Graph Display
/// @{

/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// output to file with graphviz format.
void saveGraph(const std::string& filename,
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
Expand Down
Loading
Loading