From 82781bd569cc2f784f7352f17e52599aa33e19d5 Mon Sep 17 00:00:00 2001 From: Sergei Sergienko Date: Fri, 18 Oct 2024 18:55:59 +0300 Subject: [PATCH] feat: Added depth extraction functionality --- examples/python/raycast/depth_exctraction.py | 52 ++++++++++ library/python/CMakeLists.txt | 3 +- library/python/include/pywavemap/raycast.h | 12 +++ library/python/src/pywavemap.cc | 4 + library/python/src/pywavemap/__init__.py | 1 + library/python/src/raycast.cc | 99 ++++++++++++++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 examples/python/raycast/depth_exctraction.py create mode 100644 library/python/include/pywavemap/raycast.h create mode 100644 library/python/src/raycast.cc diff --git a/examples/python/raycast/depth_exctraction.py b/examples/python/raycast/depth_exctraction.py new file mode 100644 index 000000000..068f56b7e --- /dev/null +++ b/examples/python/raycast/depth_exctraction.py @@ -0,0 +1,52 @@ +""" +Depth map extraction from HashedWaveletOctree map at given camera pose and intrinsics +""" + +import time +from pathlib import Path + +import numpy as np +from PIL import Image +import pywavemap as wm + +def save_depth_as_png(depth_map: np.ndarray, out_path: Path): + depth_min = np.min(depth_map) + depth_max = np.max(depth_map) + + # Avoid division by zero in case all values are the same + if depth_max - depth_min > 0: + depth_map_normalized = (depth_map - depth_min) / (depth_max - depth_min) + else: + depth_map_normalized = np.zeros_like(depth_map) + + # Convert floats (meters) to uint8 and save to png + depth_map_8bit = (depth_map_normalized * 255).astype(np.uint8) + image = Image.fromarray(depth_map_8bit) + image.save(out_path) + +if __name__ == "__main__": + map_path = Path.home() / "data/panoptic_mapping/flat_dataset/run2/your_map.wvmp" + out_path = Path(__file__).parent / "depth.png" + camera_cfg = wm.PinholeCameraProjectorConfig( + width=1280, + height=720, + fx=526.21539307, + fy=526.21539307, + cx=642.309021, + cy=368.69949341, + ) # Note: these are intrinsics for Zed 2i + + # Load map from file + map = wm.Map.load(map_path) + + # Create pose + rotation = wm.Rotation(np.eye(3)) + translation = np.zeros(3) + pose = wm.Pose(rotation, translation) + + # Extract depth + t1 = time.perf_counter() + depth = wm.get_depth(map, pose, camera_cfg, 0.1, 10) + t2 = time.perf_counter() + print(f"Depth map of size {camera_cfg.width}x{camera_cfg.height} created in {t2-t1:.02f} seconds") + save_depth_as_png(depth, out_path) \ No newline at end of file diff --git a/library/python/CMakeLists.txt b/library/python/CMakeLists.txt index 23ad9ab77..49a0e7a0d 100644 --- a/library/python/CMakeLists.txt +++ b/library/python/CMakeLists.txt @@ -63,7 +63,8 @@ nanobind_add_module(_pywavemap_bindings STABLE_ABI src/maps.cc src/measurements.cc src/param.cc - src/pipeline.cc) + src/pipeline.cc + src/raycast.cc) set_wavemap_target_properties(_pywavemap_bindings) target_include_directories(_pywavemap_bindings PRIVATE include) target_link_libraries(_pywavemap_bindings PRIVATE diff --git a/library/python/include/pywavemap/raycast.h b/library/python/include/pywavemap/raycast.h new file mode 100644 index 000000000..cf9feb34d --- /dev/null +++ b/library/python/include/pywavemap/raycast.h @@ -0,0 +1,12 @@ +#ifndef PYWAVEMAP_RAYCAST_H_ +#define PYWAVEMAP_RAYCAST_H_ + +#include + +namespace nb = nanobind; + +namespace wavemap { +void add_raycast_bindings(nb::module_& m); +} // namespace wavemap + +#endif // PYWAVEMAP_CONVERT_H_ diff --git a/library/python/src/pywavemap.cc b/library/python/src/pywavemap.cc index 2cc7ba0ae..c44791623 100644 --- a/library/python/src/pywavemap.cc +++ b/library/python/src/pywavemap.cc @@ -7,6 +7,7 @@ #include "pywavemap/measurements.h" #include "pywavemap/param.h" #include "pywavemap/pipeline.h" +#include "pywavemap/raycast.h" using namespace wavemap; // NOLINT namespace nb = nanobind; @@ -53,4 +54,7 @@ NB_MODULE(_pywavemap_bindings, m) { // Bindings for measurement integration and map update pipelines add_pipeline_bindings(m); + + // Bindings for raycasting + add_raycast_bindings(m); } diff --git a/library/python/src/pywavemap/__init__.py b/library/python/src/pywavemap/__init__.py index f1db669c3..0c4078ab3 100644 --- a/library/python/src/pywavemap/__init__.py +++ b/library/python/src/pywavemap/__init__.py @@ -9,6 +9,7 @@ HashedChunkedWaveletOctree, InterpolationMode) from ._pywavemap_bindings import Pipeline +from ._pywavemap_bindings import raycast, PinholeCameraProjectorConfig, get_depth # Binding submodules from ._pywavemap_bindings import logging, param, convert diff --git a/library/python/src/raycast.cc b/library/python/src/raycast.cc new file mode 100644 index 000000000..9bf4cd694 --- /dev/null +++ b/library/python/src/raycast.cc @@ -0,0 +1,99 @@ +#include "pywavemap/raycast.h" + +#include // to use eigen2numpy seamlessly +#include +#include +#include +#include +#include "wavemap/core/utils/iterate/ray_iterator.h" +#include +#include + +using namespace nb::literals; // NOLINT + +namespace wavemap { +FloatingPoint raycast( + const HashedWaveletOctree& map, + Point3D start_point, + Point3D end_point, + FloatingPoint threshold +) { + const FloatingPoint mcw = map.getMinCellWidth(); + const Ray ray(start_point, end_point, mcw); + for (const Index3D& ray_voxel_index : ray) { + if (map.getCellValue(ray_voxel_index) > threshold) { + const Point3D voxel_center = convert::indexToCenterPoint(ray_voxel_index, mcw); + return (voxel_center - start_point).norm(); + } + } + return (end_point - start_point).norm(); +} + +FloatingPoint raycast_fast( + QueryAccelerator& query_accelerator, + Point3D start_point, + Point3D end_point, + FloatingPoint threshold, + FloatingPoint min_cell_width +) { + const Ray ray(start_point, end_point, min_cell_width); + for (const Index3D& ray_voxel_index : ray) { + if (query_accelerator.getCellValue(ray_voxel_index) > threshold) { + const Point3D voxel_center = convert::indexToCenterPoint(ray_voxel_index, min_cell_width); + return (voxel_center - start_point).norm(); + } + } + return (end_point - start_point).norm(); +} + +void add_raycast_bindings(nb::module_& m) { + nb::class_( + m, + "PinholeCameraProjectorConfig", + "Describes pinhole camera intrinsics" + ) + .def(nb::init(), "fx"_a, "fy"_a, "cx"_a, "cy"_a, "height"_a, "width"_a) + .def_rw("width", &PinholeCameraProjectorConfig::width) + .def_rw("height", &PinholeCameraProjectorConfig::height) + .def_rw("fx", &PinholeCameraProjectorConfig::fx) + .def_rw("fy", &PinholeCameraProjectorConfig::fy) + .def_rw("cx", &PinholeCameraProjectorConfig::cx) + .def_rw("cy", &PinholeCameraProjectorConfig::cy) + .def("__repr__", [](const PinholeCameraProjectorConfig& self) { + return nb::str("PinholeCameraProjectorConfig(width={}, height={}, fx={}, fy={}, cx={}, cy={})") + .format(self.width, self.height, self.fx, self.fy, self.cx, self.cy); + }); + + m.def( + "raycast", + &raycast, + "Raycast and get first point with occopancy higher than threshold" + ); + + m.def( + "raycast_fast", + &raycast_fast, // TODO: unusable without QueryAccelerator binding + "Raycast and get first point with occopancy higher than threshold using QueryAccelerator for efficiency" + ); + + m.def( + "get_depth", + [](const HashedWaveletOctree& map, Transformation3D pose, PinholeCameraProjectorConfig cam_cfg, FloatingPoint threshold, FloatingPoint max_range){ + Image depth_image(cam_cfg.width, cam_cfg.height); + QueryAccelerator query_accelerator(map); + const FloatingPoint mcw = map.getMinCellWidth(); + const PinholeCameraProjector projection_model(cam_cfg); + auto start_point = pose.getPosition(); + for (const Index2D& index: Grid<2>(Index2D::Zero(), depth_image.getDimensions() - Index2D::Ones())) { + const Vector2D image_xy = projection_model.indexToImage(index); + const Point3D C_point = projection_model.sensorToCartesian({image_xy, max_range}); + const Point3D end_point = pose * C_point; + FloatingPoint depth = raycast_fast(query_accelerator, start_point, end_point, threshold, mcw); + depth_image.at(index) = depth; + } + return depth_image.getData().transpose().eval(); + }, + "Extract depth from octree map at using given camera pose and intrinsics" + ); +} +} // namespace wavemap