diff --git a/gtsam/slam/expressions.h b/gtsam/slam/expressions.h index 680f2d175d..c6aa02774e 100644 --- a/gtsam/slam/expressions.h +++ b/gtsam/slam/expressions.h @@ -48,6 +48,18 @@ inline Line3_ transformTo(const Pose3_ &wTc, const Line3_ &wL) { return Line3_(f, wTc, wL); } +inline Point3_ cross(const Point3_& a, const Point3_& b) { + Point3 (*f)(const Point3 &, const Point3 &, + OptionalJacobian<3, 3>, OptionalJacobian<3, 3>) = ✗ + return Point3_(f, a, b); +} + +inline Double_ dot(const Point3_& a, const Point3_& b) { + double (*f)(const Point3 &, const Point3 &, + OptionalJacobian<1, 3>, OptionalJacobian<1, 3>) = ˙ + return Double_(f, a, b); +} + namespace internal { // define getter that returns value rather than reference inline Rot3 rotation(const Pose3& pose, OptionalJacobian<3, 6> H) { diff --git a/tests/testExpressionFactor.cpp b/tests/testExpressionFactor.cpp index e3e37e7c7d..c31baeadfe 100644 --- a/tests/testExpressionFactor.cpp +++ b/tests/testExpressionFactor.cpp @@ -727,6 +727,39 @@ TEST(ExpressionFactor, variadicTemplate) { } +TEST(ExpressionFactor, crossProduct) { + auto model = noiseModel::Isotropic::Sigma(3, 1); + + // Create expression + const auto a = Vector3_(1); + const auto b = Vector3_(2); + Vector3_ f_expr = cross(a, b); + + // Check derivatives + Values values; + values.insert(1, Vector3(0.1, 0.2, 0.3)); + values.insert(2, Vector3(0.4, 0.5, 0.6)); + ExpressionFactor factor(model, Vector3::Zero(), f_expr); + EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-5, 1e-5); +} + +TEST(ExpressionFactor, dotProduct) { + auto model = noiseModel::Isotropic::Sigma(1, 1); + + // Create expression + const auto a = Vector3_(1); + const auto b = Vector3_(2); + Double_ f_expr = dot(a, b); + + // Check derivatives + Values values; + values.insert(1, Vector3(0.1, 0.2, 0.3)); + values.insert(2, Vector3(0.4, 0.5, 0.6)); + ExpressionFactor factor(model, .0, f_expr); + EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-5, 1e-5); +} + + /* ************************************************************************* */ int main() { TestResult tr;