Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrapping for hybrid nonlinear #1281

Merged
merged 1 commit into from
Aug 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ void HybridNonlinearFactorGraph::add(
FactorGraph::add(boost::make_shared<HybridNonlinearFactor>(factor));
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::add(
boost::shared_ptr<DiscreteFactor> factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
Expand Down
3 changes: 3 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
/// Add a nonlinear factor as a shared ptr.
void add(boost::shared_ptr<NonlinearFactor> factor);

/// Add a discrete factor as a shared ptr.
void add(boost::shared_ptr<DiscreteFactor> factor);

/// Print the factor graph.
void print(
const std::string& s = "HybridNonlinearFactorGraph",
Expand Down
59 changes: 58 additions & 1 deletion gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ virtual class HybridConditional {
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const;
size_t nrParents() const;
Factor* inner();
gtsam::Factor* inner();
};

#include <gtsam/hybrid/HybridDiscreteFactor.h>
virtual class HybridDiscreteFactor {
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
void print(string s = "HybridDiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
gtsam::Factor* inner();
};

#include <gtsam/hybrid/GaussianMixtureFactor.h>
Expand Down Expand Up @@ -132,6 +142,7 @@ class HybridGaussianFactorGraph {
void add(gtsam::JacobianFactor* factor);

bool empty() const;
void remove(size_t i);
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridFactor* at(size_t i) const;
Expand Down Expand Up @@ -159,4 +170,50 @@ class HybridGaussianFactorGraph {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};

#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
class HybridNonlinearFactorGraph {
HybridNonlinearFactorGraph();
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
void push_back(gtsam::HybridFactor* factor);
void push_back(gtsam::NonlinearFactor* factor);
void push_back(gtsam::HybridDiscreteFactor* factor);
void add(gtsam::NonlinearFactor* factor);
void add(gtsam::DiscreteFactor* factor);
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;

bool empty() const;
void remove(size_t i);
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridFactor* at(size_t i) const;

void print(string s = "HybridNonlinearFactorGraph\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/MixtureFactor.h>
class MixtureFactor : gtsam::HybridFactor {
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);

template <FACTOR = {gtsam::NonlinearFactor}>
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const std::vector<FACTOR*>& factors,
bool normalized = false);

double error(const gtsam::Values& continuousVals,
const gtsam::DiscreteValues& discreteVals) const;

double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor,
const gtsam::Values& values) const;

GaussianMixtureFactor* linearize(
const gtsam::Values& continuousVals) const;

void print(string s = "MixtureFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

} // namespace gtsam
7 changes: 7 additions & 0 deletions python/gtsam/preamble/hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,12 @@
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
#include <pybind11/stl.h>

#ifdef GTSAM_ALLOCATOR_TBB
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key, tbb::tbb_allocator<gtsam::Key>>);
#else
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key>);
#endif

PYBIND11_MAKE_OPAQUE(std::vector<gtsam::GaussianFactor::shared_ptr>);
55 changes: 55 additions & 0 deletions python/gtsam/tests/test_HybridNonlinearFactorGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved

See LICENSE for the license information

Unit tests for Hybrid Nonlinear Factor Graphs.
Author: Fan Jiang
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

from __future__ import print_function

import unittest

import gtsam
import numpy as np
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase


class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridGaussianFactorGraph."""

def test_nonlinear_hybrid(self):
nlfg = gtsam.HybridNonlinearFactorGraph()
dk = gtsam.DiscreteKeys()
dk.push_back((10, 2))
nlfg.add(gtsam.BetweenFactorPoint3(1, 2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([1, 1, 1])))
nlfg.add(
gtsam.PriorFactorPoint3(2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5])))
nlfg.push_back(
gtsam.MixtureFactor([1], dk, [
gtsam.PriorFactorPoint3(1, gtsam.Point3(0, 0, 0),
gtsam.noiseModel.Unit.Create(3)),
gtsam.PriorFactorPoint3(1, gtsam.Point3(1, 2, 1),
gtsam.noiseModel.Unit.Create(3))
]))
nlfg.add(gtsam.DecisionTreeFactor((10, 2), "1 3"))
values = gtsam.Values()
values.insert_point3(1, gtsam.Point3(0, 0, 0))
values.insert_point3(2, gtsam.Point3(2, 3, 1))
hfg = nlfg.linearize(values)
o = gtsam.Ordering()
o.push_back(1)
o.push_back(2)
o.push_back(10)
hbn = hfg.eliminateSequential(o)
hbv = hbn.optimize()
self.assertEqual(hbv.atDiscrete(10), 0)


if __name__ == "__main__":
unittest.main()