From 8a4889d97c71d4807599852d22ba6403d1c7094c Mon Sep 17 00:00:00 2001 From: Victor Reijgwart Date: Mon, 16 Sep 2024 16:57:40 +0200 Subject: [PATCH] Set interpolation mode through argument, not alternative method names --- docs/python_api/index.rst | 2 + .../queries/nearest_neighbor_interpolation.py | 6 +- .../python/queries/trilinear_interpolation.py | 6 +- library/python/src/maps.cc | 186 +++++++++--------- library/python/src/pywavemap/__init__.py | 3 +- library/python/test/test_pywavemap.py | 17 +- 6 files changed, 114 insertions(+), 106 deletions(-) diff --git a/docs/python_api/index.rst b/docs/python_api/index.rst index c77d011c8..0066b42f1 100644 --- a/docs/python_api/index.rst +++ b/docs/python_api/index.rst @@ -14,6 +14,8 @@ Python API .. autoclass:: pywavemap.HashedChunkedWaveletOctree :show-inheritance: :members: +.. autoclass:: pywavemap.InterpolationMode + :members: .. autoclass:: pywavemap.OctreeIndex :members: diff --git a/examples/python/queries/nearest_neighbor_interpolation.py b/examples/python/queries/nearest_neighbor_interpolation.py index 02a405fbd..791fb907f 100644 --- a/examples/python/queries/nearest_neighbor_interpolation.py +++ b/examples/python/queries/nearest_neighbor_interpolation.py @@ -1,4 +1,5 @@ import numpy as np +from pywavemap import InterpolationMode import _dummy_objects # Load a map @@ -8,10 +9,11 @@ query_point = np.array([0.4, 0.5, 0.6]) # Query a single point -occupancy_log_odds = your_map.interpolateNearest(query_point) +occupancy_log_odds = your_map.interpolate(query_point, + InterpolationMode.NEAREST) print(occupancy_log_odds) # Vectorized query for a list of points points = np.random.random(size=(64 * 64 * 32, 3)) -points_log_odds = your_map.interpolateNearest(points) +points_log_odds = your_map.interpolate(points, InterpolationMode.NEAREST) print(points_log_odds) diff --git a/examples/python/queries/trilinear_interpolation.py b/examples/python/queries/trilinear_interpolation.py index 78525c370..412822b4a 100644 --- a/examples/python/queries/trilinear_interpolation.py +++ b/examples/python/queries/trilinear_interpolation.py @@ -1,4 +1,5 @@ import numpy as np +from pywavemap import InterpolationMode import _dummy_objects # Load a map @@ -8,10 +9,11 @@ query_point = np.array([0.4, 0.5, 0.6]) # Query a single point -occupancy_log_odds = your_map.interpolateTrilinear(query_point) +occupancy_log_odds = your_map.interpolate(query_point, + InterpolationMode.TRILINEAR) print(occupancy_log_odds) # Vectorized query for a list of points points = np.random.random(size=(64 * 64 * 32, 3)) -points_log_odds = your_map.interpolateTrilinear(points) +points_log_odds = your_map.interpolate(points, InterpolationMode.TRILINEAR) print(points_log_odds) diff --git a/library/python/src/maps.cc b/library/python/src/maps.cc index e35ac6206..036e956bf 100644 --- a/library/python/src/maps.cc +++ b/library/python/src/maps.cc @@ -15,6 +15,14 @@ using namespace nb::literals; // NOLINT namespace wavemap { void add_map_bindings(nb::module_& m) { + enum class InterpolationMode { kNearest, kTrilinear }; + + nb::enum_(m, "InterpolationMode") + .value("NEAREST", InterpolationMode::kNearest, + "Look up the value of the nearest map cell.") + .value("TRILINEAR", InterpolationMode::kTrilinear, + "Interpolate linearly along each map axis."); + nb::class_(m, "Map", "Base class for wavemap maps.") .def_prop_ro("empty", &MapBase::empty, "Whether the map is empty.") .def_prop_ro("size", &MapBase::size, @@ -59,20 +67,21 @@ void add_map_bindings(nb::module_& m) { .def("add_to_cell_value", &MapBase::addToCellValue, "index"_a, "update"_a, "Increment the value of the map at a given index.") .def( - "interpolateNearest", - [](const MapBase& self, const Point3D& position) { - return interpolate::nearestNeighbor(self, position); - }, - "position"_a, - "Query the map's value at a point using nearest neighbor " - "interpolation.") - .def( - "interpolateTrilinear", - [](const MapBase& self, const Point3D& position) { - return interpolate::trilinear(self, position); + "interpolate", + [](const MapBase& self, const Point3D& position, + InterpolationMode mode) { + switch (mode) { + case InterpolationMode::kNearest: + return interpolate::nearestNeighbor(self, position); + case InterpolationMode::kTrilinear: + return interpolate::trilinear(self, position); + default: + throw nb::type_error("Unknown interpolation mode."); + } }, - "position"_a, - "Query the map's value at a point using trilinear interpolation.") + "position"_a, "mode"_a = InterpolationMode::kTrilinear, + "Query the map's value at a point, using the specified interpolation " + "mode.") .def_static( "create", [](const param::Value& params) -> std::shared_ptr { @@ -117,18 +126,18 @@ void add_map_bindings(nb::module_& m) { // Create nb::ndarray view for efficient access to the query indices const auto index_view = indices.view(); const auto num_queries = index_view.shape(0); - // Allocate and populate raw results array + // Create the raw results array and wrap it in a Python capsule that + // deallocates it when all references to it expire auto* results = new float[num_queries]; + nb::capsule owner(results, [](void* p) noexcept { + delete[] reinterpret_cast(p); + }); + // Compute the interpolated values for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { results[query_idx] = query_accelerator.getCellValue( {index_view(query_idx, 0), index_view(query_idx, 1), index_view(query_idx, 2)}); } - // Create Python capsule that deallocates the results array when - // all references to it expire - nb::capsule owner(results, [](void* p) noexcept { - delete[] reinterpret_cast(p); - }); // Return results as numpy array return nb::ndarray{ results, {num_queries, 1u}, owner}; @@ -146,8 +155,13 @@ void add_map_bindings(nb::module_& m) { // Create nb::ndarray view for efficient access to the query indices auto index_view = indices.view(); const auto num_queries = index_view.shape(0); - // Allocate and populate raw results array + // Create the raw results array and wrap it in a Python capsule that + // deallocates it when all references to it expire auto* results = new float[num_queries]; + nb::capsule owner(results, [](void* p) noexcept { + delete[] reinterpret_cast(p); + }); + // Compute the interpolated values for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { const OctreeIndex node_index{ index_view(query_idx, 0), @@ -155,11 +169,6 @@ void add_map_bindings(nb::module_& m) { index_view(query_idx, 3)}}; results[query_idx] = query_accelerator.getCellValue(node_index); } - // Create Python capsule that deallocates the results array when - // all references to it expire - nb::capsule owner(results, [](void* p) noexcept { - delete[] reinterpret_cast(p); - }); // Return results as numpy array return nb::ndarray{ results, {num_queries, 1u}, owner}; @@ -168,80 +177,68 @@ void add_map_bindings(nb::module_& m) { "Query the map at the given node indices, provided as a matrix with " "one (height, x, y, z) node index per row.") .def( - "interpolateNearest", - [](const MapBase& self, const Point3D& position) { - return interpolate::nearestNeighbor(self, position); - }, - "position"_a, - "Query the map's value at a point using nearest neighbor " - "interpolation.") - .def( - "interpolateTrilinear", - [](const MapBase& self, const Point3D& position) { - return interpolate::trilinear(self, position); - }, - "position"_a, - "Query the map's value at a point using trilinear interpolation.") - .def( - "interpolateNearest", - [](const HashedWaveletOctree& self, - const nb::ndarray, - nb::device::cpu>& positions) { - // Create a query accelerator - QueryAccelerator query_accelerator{self}; - // Create nb::ndarray view for efficient access to the query points - const auto positions_view = positions.view(); - const auto num_queries = positions_view.shape(0); - // Allocate and populate raw results array - auto* results = new float[num_queries]; - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - results[query_idx] = interpolate::nearestNeighbor( - query_accelerator, - {positions_view(query_idx, 0), positions_view(query_idx, 1), - positions_view(query_idx, 2)}); + "interpolate", + [](const MapBase& self, const Point3D& position, + InterpolationMode mode) { + switch (mode) { + case InterpolationMode::kNearest: + return interpolate::nearestNeighbor(self, position); + case InterpolationMode::kTrilinear: + return interpolate::trilinear(self, position); + default: + throw nb::type_error("Unknown interpolation mode."); } - // Create Python capsule that deallocates the results array when - // all references to it expire - nb::capsule owner(results, [](void* p) noexcept { - delete[] reinterpret_cast(p); - }); - // Return results as numpy array - return nb::ndarray{ - results, {num_queries, 1u}, owner}; }, - "position_list"_a, - "Query the map's value at the given points using nearest neighbor " - "interpolation.") + "position"_a, "mode"_a = InterpolationMode::kTrilinear, + "Query the map's value at a point, using the specified interpolation " + "mode.") .def( - "interpolateTrilinear", + "interpolate", [](const HashedWaveletOctree& self, const nb::ndarray, - nb::device::cpu>& positions) { + nb::device::cpu>& positions, + InterpolationMode mode) { // Create a query accelerator QueryAccelerator query_accelerator{self}; // Create nb::ndarray view for efficient access to the query points const auto positions_view = positions.view(); const auto num_queries = positions_view.shape(0); - // Allocate and populate raw results array + // Create the raw results array and wrap it in a Python capsule that + // deallocates it when all references to it expire auto* results = new float[num_queries]; - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - results[query_idx] = interpolate::trilinear( - query_accelerator, - {positions_view(query_idx, 0), positions_view(query_idx, 1), - positions_view(query_idx, 2)}); - } - // Create Python capsule that deallocates the results array when - // all references to it expire nb::capsule owner(results, [](void* p) noexcept { delete[] reinterpret_cast(p); }); + // Compute the interpolated values + switch (mode) { + case InterpolationMode::kNearest: + for (size_t query_idx = 0; query_idx < num_queries; + ++query_idx) { + results[query_idx] = interpolate::nearestNeighbor( + query_accelerator, {positions_view(query_idx, 0), + positions_view(query_idx, 1), + positions_view(query_idx, 2)}); + } + break; + case InterpolationMode::kTrilinear: + for (size_t query_idx = 0; query_idx < num_queries; + ++query_idx) { + results[query_idx] = interpolate::trilinear( + query_accelerator, {positions_view(query_idx, 0), + positions_view(query_idx, 1), + positions_view(query_idx, 2)}); + } + break; + default: + throw nb::type_error("Unknown interpolation mode."); + } // Return results as numpy array return nb::ndarray{ results, {num_queries, 1u}, owner}; }, - "position_list"_a, - "Query the map's value at the given points using trilinear " - "interpolation."); + "position_list"_a, "mode"_a = InterpolationMode::kTrilinear, + "Query the map's value at the given points, using the specified " + "interpolation mode."); nb::class_( m, "HashedChunkedWaveletOctree", @@ -254,19 +251,20 @@ void add_map_bindings(nb::module_& m) { "node_index"_a, "Query the value of the map at a given octree node index.") .def( - "interpolateNearest", - [](const MapBase& self, const Point3D& position) { - return interpolate::nearestNeighbor(self, position); - }, - "position"_a, - "Query the map's value at a point using nearest neighbor " - "interpolation.") - .def( - "interpolateTrilinear", - [](const MapBase& self, const Point3D& position) { - return interpolate::trilinear(self, position); + "interpolate", + [](const MapBase& self, const Point3D& position, + InterpolationMode mode) { + switch (mode) { + case InterpolationMode::kNearest: + return interpolate::nearestNeighbor(self, position); + case InterpolationMode::kTrilinear: + return interpolate::trilinear(self, position); + default: + throw nb::type_error("Unknown interpolation mode."); + } }, - "position"_a, - "Query the map's value at a point using trilinear interpolation."); + "position"_a, "mode"_a = InterpolationMode::kTrilinear, + "Query the map's value at a point, using the specified interpolation " + "mode."); } } // namespace wavemap diff --git a/library/python/src/pywavemap/__init__.py b/library/python/src/pywavemap/__init__.py index fd40eb74f..f1db669c3 100644 --- a/library/python/src/pywavemap/__init__.py +++ b/library/python/src/pywavemap/__init__.py @@ -6,7 +6,8 @@ from ._pywavemap_bindings import (Rotation, Pose, Pointcloud, PosedPointcloud, Image, PosedImage) from ._pywavemap_bindings import (Map, HashedWaveletOctree, - HashedChunkedWaveletOctree) + HashedChunkedWaveletOctree, + InterpolationMode) from ._pywavemap_bindings import Pipeline # Binding submodules diff --git a/library/python/test/test_pywavemap.py b/library/python/test/test_pywavemap.py index d37b8c3c7..9b06bdecc 100644 --- a/library/python/test/test_pywavemap.py +++ b/library/python/test/test_pywavemap.py @@ -29,7 +29,7 @@ def test_batched_fixed_resolution_queries(): def test_batched_multi_resolution_queries(): import numpy as np - import pywavemap + import pywavemap as wave test_map = load_test_map() @@ -38,33 +38,36 @@ def test_batched_multi_resolution_queries(): cell_indices = np.concatenate((cell_heights, cell_positions), axis=1) cell_values = test_map.get_cell_values(cell_indices) for cell_idx in range(cell_positions.shape[0]): - cell_index = pywavemap.OctreeIndex(cell_heights[cell_idx], - cell_positions[cell_idx, :]) + cell_index = wave.OctreeIndex(cell_heights[cell_idx], + cell_positions[cell_idx, :]) cell_value = test_map.get_cell_value(cell_index) assert cell_values[cell_idx] == cell_value def test_batched_nearest_neighbor_interpolation(): import numpy as np + from pywavemap import InterpolationMode test_map = load_test_map() points = np.random.random(size=(64 * 64 * 32, 3)) - points_log_odds = test_map.interpolateNearest(points) + points_log_odds = test_map.interpolate(points, InterpolationMode.NEAREST) for point_idx in range(points.shape[0]): point = points[point_idx, :] - point_log_odds = test_map.interpolateNearest(point) + point_log_odds = test_map.interpolate(point, InterpolationMode.NEAREST) assert points_log_odds[point_idx] == point_log_odds def test_batched_trilinear_interpolation(): import numpy as np + from pywavemap import InterpolationMode test_map = load_test_map() points = np.random.random(size=(64 * 64 * 32, 3)) - points_log_odds = test_map.interpolateTrilinear(points) + points_log_odds = test_map.interpolate(points, InterpolationMode.TRILINEAR) for point_idx in range(points.shape[0]): point = points[point_idx, :] - point_log_odds = test_map.interpolateTrilinear(point) + point_log_odds = test_map.interpolate(point, + InterpolationMode.TRILINEAR) assert points_log_odds[point_idx] == point_log_odds