Skip to content

Commit

Permalink
Merge pull request #1089 from borglab/fix/inference_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Feb 6, 2022
2 parents efe922b + 72772b1 commit e5e9996
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
1 change: 0 additions & 1 deletion gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
Expand Down
2 changes: 2 additions & 0 deletions python/gtsam/preamble/inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/

#include <pybind11/stl.h>
35 changes: 35 additions & 0 deletions python/gtsam/tests/test_DiscreteBayesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# pylint: disable=no-name-in-module, invalid-name

import unittest
import textwrap

import gtsam
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
Expand Down Expand Up @@ -126,6 +128,39 @@ def test_fragment(self):
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)

def test_dot(self):
"""Check that dot works with position hints."""
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
MyAsia = gtsam.symbol('a', 0), 2 # use a symbol!
fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")

# Make sure we can *update* position hints
writer = gtsam.DotWriter()
ph: dict = writer.positionHints
ph.update({'a': 2}) # hint at symbol position
writer.positionHints = ph

# Check the output of dot
actual = fragment.dot(writer=writer)
expected_result = """\
digraph {
size="5,5";
var3[label="3"];
var4[label="4"];
var5[label="5"];
var6[label="6"];
vara0[label="a0", pos="0,2!"];
var4->var6
vara0->var3
var3->var5
var6->var5
}"""
self.assertEqual(actual, textwrap.dedent(expected_result))


if __name__ == "__main__":
unittest.main()

0 comments on commit e5e9996

Please sign in to comment.