Skip to content

Commit

Permalink
Set interpolation mode through argument, not alternative method names
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Sep 16, 2024
1 parent 450d9ae commit 8a4889d
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 106 deletions.
2 changes: 2 additions & 0 deletions docs/python_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Python API
.. autoclass:: pywavemap.HashedChunkedWaveletOctree
:show-inheritance:
:members:
.. autoclass:: pywavemap.InterpolationMode
:members:

.. autoclass:: pywavemap.OctreeIndex
:members:
Expand Down
6 changes: 4 additions & 2 deletions examples/python/queries/nearest_neighbor_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from pywavemap import InterpolationMode
import _dummy_objects

# Load a map
Expand All @@ -8,10 +9,11 @@
query_point = np.array([0.4, 0.5, 0.6])

# Query a single point
occupancy_log_odds = your_map.interpolateNearest(query_point)
occupancy_log_odds = your_map.interpolate(query_point,
InterpolationMode.NEAREST)
print(occupancy_log_odds)

# Vectorized query for a list of points
points = np.random.random(size=(64 * 64 * 32, 3))
points_log_odds = your_map.interpolateNearest(points)
points_log_odds = your_map.interpolate(points, InterpolationMode.NEAREST)
print(points_log_odds)
6 changes: 4 additions & 2 deletions examples/python/queries/trilinear_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from pywavemap import InterpolationMode
import _dummy_objects

# Load a map
Expand All @@ -8,10 +9,11 @@
query_point = np.array([0.4, 0.5, 0.6])

# Query a single point
occupancy_log_odds = your_map.interpolateTrilinear(query_point)
occupancy_log_odds = your_map.interpolate(query_point,
InterpolationMode.TRILINEAR)
print(occupancy_log_odds)

# Vectorized query for a list of points
points = np.random.random(size=(64 * 64 * 32, 3))
points_log_odds = your_map.interpolateTrilinear(points)
points_log_odds = your_map.interpolate(points, InterpolationMode.TRILINEAR)
print(points_log_odds)
186 changes: 92 additions & 94 deletions library/python/src/maps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ using namespace nb::literals; // NOLINT

namespace wavemap {
void add_map_bindings(nb::module_& m) {
enum class InterpolationMode { kNearest, kTrilinear };

nb::enum_<InterpolationMode>(m, "InterpolationMode")
.value("NEAREST", InterpolationMode::kNearest,
"Look up the value of the nearest map cell.")
.value("TRILINEAR", InterpolationMode::kTrilinear,
"Interpolate linearly along each map axis.");

nb::class_<MapBase>(m, "Map", "Base class for wavemap maps.")
.def_prop_ro("empty", &MapBase::empty, "Whether the map is empty.")
.def_prop_ro("size", &MapBase::size,
Expand Down Expand Up @@ -59,20 +67,21 @@ void add_map_bindings(nb::module_& m) {
.def("add_to_cell_value", &MapBase::addToCellValue, "index"_a, "update"_a,
"Increment the value of the map at a given index.")
.def(
"interpolateNearest",
[](const MapBase& self, const Point3D& position) {
return interpolate::nearestNeighbor(self, position);
},
"position"_a,
"Query the map's value at a point using nearest neighbor "
"interpolation.")
.def(
"interpolateTrilinear",
[](const MapBase& self, const Point3D& position) {
return interpolate::trilinear(self, position);
"interpolate",
[](const MapBase& self, const Point3D& position,
InterpolationMode mode) {
switch (mode) {
case InterpolationMode::kNearest:
return interpolate::nearestNeighbor(self, position);
case InterpolationMode::kTrilinear:
return interpolate::trilinear(self, position);
default:
throw nb::type_error("Unknown interpolation mode.");
}
},
"position"_a,
"Query the map's value at a point using trilinear interpolation.")
"position"_a, "mode"_a = InterpolationMode::kTrilinear,
"Query the map's value at a point, using the specified interpolation "
"mode.")
.def_static(
"create",
[](const param::Value& params) -> std::shared_ptr<MapBase> {
Expand Down Expand Up @@ -117,18 +126,18 @@ void add_map_bindings(nb::module_& m) {
// Create nb::ndarray view for efficient access to the query indices
const auto index_view = indices.view();
const auto num_queries = index_view.shape(0);
// Allocate and populate raw results array
// Create the raw results array and wrap it in a Python capsule that
// deallocates it when all references to it expire
auto* results = new float[num_queries];
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Compute the interpolated values
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
results[query_idx] = query_accelerator.getCellValue(
{index_view(query_idx, 0), index_view(query_idx, 1),
index_view(query_idx, 2)});
}
// Create Python capsule that deallocates the results array when
// all references to it expire
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Return results as numpy array
return nb::ndarray<nb::numpy, float>{
results, {num_queries, 1u}, owner};
Expand All @@ -146,20 +155,20 @@ void add_map_bindings(nb::module_& m) {
// Create nb::ndarray view for efficient access to the query indices
auto index_view = indices.view();
const auto num_queries = index_view.shape(0);
// Allocate and populate raw results array
// Create the raw results array and wrap it in a Python capsule that
// deallocates it when all references to it expire
auto* results = new float[num_queries];
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Compute the interpolated values
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
const OctreeIndex node_index{
index_view(query_idx, 0),
{index_view(query_idx, 1), index_view(query_idx, 2),
index_view(query_idx, 3)}};
results[query_idx] = query_accelerator.getCellValue(node_index);
}
// Create Python capsule that deallocates the results array when
// all references to it expire
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Return results as numpy array
return nb::ndarray<nb::numpy, float>{
results, {num_queries, 1u}, owner};
Expand All @@ -168,80 +177,68 @@ void add_map_bindings(nb::module_& m) {
"Query the map at the given node indices, provided as a matrix with "
"one (height, x, y, z) node index per row.")
.def(
"interpolateNearest",
[](const MapBase& self, const Point3D& position) {
return interpolate::nearestNeighbor(self, position);
},
"position"_a,
"Query the map's value at a point using nearest neighbor "
"interpolation.")
.def(
"interpolateTrilinear",
[](const MapBase& self, const Point3D& position) {
return interpolate::trilinear(self, position);
},
"position"_a,
"Query the map's value at a point using trilinear interpolation.")
.def(
"interpolateNearest",
[](const HashedWaveletOctree& self,
const nb::ndarray<FloatingPoint, nb::shape<-1, 3>,
nb::device::cpu>& positions) {
// Create a query accelerator
QueryAccelerator<HashedWaveletOctree> query_accelerator{self};
// Create nb::ndarray view for efficient access to the query points
const auto positions_view = positions.view();
const auto num_queries = positions_view.shape(0);
// Allocate and populate raw results array
auto* results = new float[num_queries];
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
results[query_idx] = interpolate::nearestNeighbor(
query_accelerator,
{positions_view(query_idx, 0), positions_view(query_idx, 1),
positions_view(query_idx, 2)});
"interpolate",
[](const MapBase& self, const Point3D& position,
InterpolationMode mode) {
switch (mode) {
case InterpolationMode::kNearest:
return interpolate::nearestNeighbor(self, position);
case InterpolationMode::kTrilinear:
return interpolate::trilinear(self, position);
default:
throw nb::type_error("Unknown interpolation mode.");
}
// Create Python capsule that deallocates the results array when
// all references to it expire
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Return results as numpy array
return nb::ndarray<nb::numpy, float>{
results, {num_queries, 1u}, owner};
},
"position_list"_a,
"Query the map's value at the given points using nearest neighbor "
"interpolation.")
"position"_a, "mode"_a = InterpolationMode::kTrilinear,
"Query the map's value at a point, using the specified interpolation "
"mode.")
.def(
"interpolateTrilinear",
"interpolate",
[](const HashedWaveletOctree& self,
const nb::ndarray<FloatingPoint, nb::shape<-1, 3>,
nb::device::cpu>& positions) {
nb::device::cpu>& positions,
InterpolationMode mode) {
// Create a query accelerator
QueryAccelerator<HashedWaveletOctree> query_accelerator{self};
// Create nb::ndarray view for efficient access to the query points
const auto positions_view = positions.view();
const auto num_queries = positions_view.shape(0);
// Allocate and populate raw results array
// Create the raw results array and wrap it in a Python capsule that
// deallocates it when all references to it expire
auto* results = new float[num_queries];
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
results[query_idx] = interpolate::trilinear(
query_accelerator,
{positions_view(query_idx, 0), positions_view(query_idx, 1),
positions_view(query_idx, 2)});
}
// Create Python capsule that deallocates the results array when
// all references to it expire
nb::capsule owner(results, [](void* p) noexcept {
delete[] reinterpret_cast<float*>(p);
});
// Compute the interpolated values
switch (mode) {
case InterpolationMode::kNearest:
for (size_t query_idx = 0; query_idx < num_queries;
++query_idx) {
results[query_idx] = interpolate::nearestNeighbor(
query_accelerator, {positions_view(query_idx, 0),
positions_view(query_idx, 1),
positions_view(query_idx, 2)});
}
break;
case InterpolationMode::kTrilinear:
for (size_t query_idx = 0; query_idx < num_queries;
++query_idx) {
results[query_idx] = interpolate::trilinear(
query_accelerator, {positions_view(query_idx, 0),
positions_view(query_idx, 1),
positions_view(query_idx, 2)});
}
break;
default:
throw nb::type_error("Unknown interpolation mode.");
}
// Return results as numpy array
return nb::ndarray<nb::numpy, float>{
results, {num_queries, 1u}, owner};
},
"position_list"_a,
"Query the map's value at the given points using trilinear "
"interpolation.");
"position_list"_a, "mode"_a = InterpolationMode::kTrilinear,
"Query the map's value at the given points, using the specified "
"interpolation mode.");

nb::class_<HashedChunkedWaveletOctree, MapBase>(
m, "HashedChunkedWaveletOctree",
Expand All @@ -254,19 +251,20 @@ void add_map_bindings(nb::module_& m) {
"node_index"_a,
"Query the value of the map at a given octree node index.")
.def(
"interpolateNearest",
[](const MapBase& self, const Point3D& position) {
return interpolate::nearestNeighbor(self, position);
},
"position"_a,
"Query the map's value at a point using nearest neighbor "
"interpolation.")
.def(
"interpolateTrilinear",
[](const MapBase& self, const Point3D& position) {
return interpolate::trilinear(self, position);
"interpolate",
[](const MapBase& self, const Point3D& position,
InterpolationMode mode) {
switch (mode) {
case InterpolationMode::kNearest:
return interpolate::nearestNeighbor(self, position);
case InterpolationMode::kTrilinear:
return interpolate::trilinear(self, position);
default:
throw nb::type_error("Unknown interpolation mode.");
}
},
"position"_a,
"Query the map's value at a point using trilinear interpolation.");
"position"_a, "mode"_a = InterpolationMode::kTrilinear,
"Query the map's value at a point, using the specified interpolation "
"mode.");
}
} // namespace wavemap
3 changes: 2 additions & 1 deletion library/python/src/pywavemap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from ._pywavemap_bindings import (Rotation, Pose, Pointcloud, PosedPointcloud,
Image, PosedImage)
from ._pywavemap_bindings import (Map, HashedWaveletOctree,
HashedChunkedWaveletOctree)
HashedChunkedWaveletOctree,
InterpolationMode)
from ._pywavemap_bindings import Pipeline

# Binding submodules
Expand Down
17 changes: 10 additions & 7 deletions library/python/test/test_pywavemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_batched_fixed_resolution_queries():

def test_batched_multi_resolution_queries():
import numpy as np
import pywavemap
import pywavemap as wave

test_map = load_test_map()

Expand All @@ -38,33 +38,36 @@ def test_batched_multi_resolution_queries():
cell_indices = np.concatenate((cell_heights, cell_positions), axis=1)
cell_values = test_map.get_cell_values(cell_indices)
for cell_idx in range(cell_positions.shape[0]):
cell_index = pywavemap.OctreeIndex(cell_heights[cell_idx],
cell_positions[cell_idx, :])
cell_index = wave.OctreeIndex(cell_heights[cell_idx],
cell_positions[cell_idx, :])
cell_value = test_map.get_cell_value(cell_index)
assert cell_values[cell_idx] == cell_value


def test_batched_nearest_neighbor_interpolation():
import numpy as np
from pywavemap import InterpolationMode

test_map = load_test_map()

points = np.random.random(size=(64 * 64 * 32, 3))
points_log_odds = test_map.interpolateNearest(points)
points_log_odds = test_map.interpolate(points, InterpolationMode.NEAREST)
for point_idx in range(points.shape[0]):
point = points[point_idx, :]
point_log_odds = test_map.interpolateNearest(point)
point_log_odds = test_map.interpolate(point, InterpolationMode.NEAREST)
assert points_log_odds[point_idx] == point_log_odds


def test_batched_trilinear_interpolation():
import numpy as np
from pywavemap import InterpolationMode

test_map = load_test_map()

points = np.random.random(size=(64 * 64 * 32, 3))
points_log_odds = test_map.interpolateTrilinear(points)
points_log_odds = test_map.interpolate(points, InterpolationMode.TRILINEAR)
for point_idx in range(points.shape[0]):
point = points[point_idx, :]
point_log_odds = test_map.interpolateTrilinear(point)
point_log_odds = test_map.interpolate(point,
InterpolationMode.TRILINEAR)
assert points_log_odds[point_idx] == point_log_odds

0 comments on commit 8a4889d

Please sign in to comment.