|
| 1 | +# 飞桨适配 Open3D |
| 2 | + |
| 3 | +> RFC 文档相关记录信息 |
| 4 | +
|
| 5 | +| | | |
| 6 | +| ------------ | -------------------- | |
| 7 | +| 提交作者 | SecretXV | |
| 8 | +| 提交时间 | 2024-10-18 | |
| 9 | +| RFC 版本号 | v1.0 | |
| 10 | +| 依赖飞桨版本 | develop | |
| 11 | +| 文件名 | 20241018_open3d_for_paddle.md | |
| 12 | + |
| 13 | +## 1. 概述 |
| 14 | + |
| 15 | +### 1.1 相关背景 |
| 16 | + |
| 17 | +Open3D是一个开源库,专门用于处理三维数据的计算机视觉和图形学任务。它由英伟达(NVIDIA)资助,并且广泛应用于学术研究和工业界。 Open3D提供了丰富的工具集,用于 3D数据处理、可视化、机器学习、和 几何运算,使得用户可以高效地处理点云、网格、体素等三维数据。 |
| 18 | + |
| 19 | +### 1.2 功能目标 |
| 20 | + |
| 21 | +1. 完成 Open3D 和 Paddle 联合编译 |
| 22 | +2. 完成 PyTorch 算子向 Paddle 迁移,原则上所有算子都需要迁移 |
| 23 | +3. 完成 Paddle Layer 封装,结构与 PyTorch 保持一致 |
| 24 | +4. 完成相关代码合入 Open3D 仓库 |
| 25 | + |
| 26 | +### 1.3 意义 |
| 27 | + |
| 28 | +Open3D 集成了机器学习的功能,简化 3D 数据的机器学习任务,例如点云分类、分割和物体检测。它提供了易于使用的 API 和工具集,能够与主流深度学习框架(如 PyTorch 和 TensorFlow)无缝集成。通过为 Open3D 适配 paddle 后端,可以将 PyTorch 上的代码快速迁移至 Paddle,丰富 Paddle 在点云分类、分割和物体检测等场景的能力。 |
| 29 | + |
| 30 | +## 2. Open3D 现状 |
| 31 | + |
| 32 | +Open3D中实现了机器学习相关的模块,已经支持 PyTorch 和 Tensorflow 后端,具体包括:相关算子支持,dataset,loss,pipeline,visualize功能支持。其中,dataset,loss,pipeline,visualize功能具体在Open3D-ML中实现。 |
| 33 | + |
| 34 | +> NOTE:本次适配计划 Open3D-ML 非强制要求 |
| 35 | +
|
| 36 | +## 3. 目标调研 |
| 37 | + |
| 38 | +### 3.1 后端集成方案调研 |
| 39 | +不用CppExtension/CudaExtension,在Open3D中采用CMake组织自定义算子编译 |
| 40 | + 1. 优点:与pytorch和tf的自定义算子编译方式保持统一 |
| 41 | + 2. 缺点:需要自行组织算子编译;需要将paddle的cpp_extension在Open3D中进行实现。 |
| 42 | + |
| 43 | +采用CppExtension/CudaExtension进行自定义算子编译 |
| 44 | + 1. 优点:实现便捷 |
| 45 | + 2. 缺点:采用CppExtension/CudaExtension编译的自定义算子会生成一个独立的模块,需要考虑如何自动化的打包到whl包中;算子存在第三方依赖,需要实现 CMake 和 setup.py 的参数透传 |
| 46 | +### 3.2 算子注册方案调研 |
| 47 | +Open3D中torch算子根据是否有反向算子,存在两种注册方式: |
| 48 | +1. 对于存在反向的算子(e.g ContinuousConv),会创建一个继承自torch.autograd.Function的类,并实现forward和backward方法。本质是为了让autograd根据forward函数查找到对应的backward函数 |
| 49 | + |
| 50 | +```cpp |
| 51 | +using namespace open3d::ml::impl; |
| 52 | +using torch::autograd::AutogradContext; |
| 53 | +using torch::autograd::Function; |
| 54 | +using torch::autograd::Variable; |
| 55 | +using torch::autograd::variable_list; |
| 56 | + |
| 57 | +class ContinuousConvFunction : public Function<ContinuousConvFunction> { |
| 58 | +public: |
| 59 | + static Variable forward(AutogradContext* ctx, |
| 60 | + Variable filters, |
| 61 | + Variable out_positions, |
| 62 | + Variable extents, |
| 63 | + Variable offset, |
| 64 | + Variable inp_positions, |
| 65 | + Variable inp_features, |
| 66 | + Variable inp_importance, |
| 67 | + Variable neighbors_index, |
| 68 | + Variable neighbors_importance, |
| 69 | + Variable neighbors_row_splits, |
| 70 | + const bool align_corners, |
| 71 | + const std::string& coordinate_mapping_str, |
| 72 | + const bool normalize, |
| 73 | + const std::string& interpolation_str, |
| 74 | + const int64_t max_temp_mem_MB) { |
| 75 | + ... |
| 76 | + } |
| 77 | + |
| 78 | + static variable_list backward(AutogradContext* ctx, |
| 79 | + variable_list grad_output) { |
| 80 | + ... |
| 81 | +}; |
| 82 | +torch::Tensor ContinuousConv(const torch::Tensor& filters, |
| 83 | + const torch::Tensor& out_positions, |
| 84 | + const torch::Tensor& extents, |
| 85 | + const torch::Tensor& offset, |
| 86 | + const torch::Tensor& inp_positions, |
| 87 | + const torch::Tensor& inp_features, |
| 88 | + const torch::Tensor& inp_importance, |
| 89 | + const torch::Tensor& neighbors_index, |
| 90 | + const torch::Tensor& neighbors_importance, |
| 91 | + const torch::Tensor& neighbors_row_splits, |
| 92 | + const bool align_corners, |
| 93 | + const std::string& coordinate_mapping_str, |
| 94 | + const bool normalize, |
| 95 | + const std::string& interpolation_str, |
| 96 | + const int64_t max_temp_mem_MB) { |
| 97 | + auto ans = ContinuousConvFunction::apply( |
| 98 | + filters, out_positions, extents, offset, inp_positions, |
| 99 | + inp_features, inp_importance, neighbors_index, neighbors_importance, |
| 100 | + neighbors_row_splits, align_corners, coordinate_mapping_str, |
| 101 | + normalize, interpolation_str, max_temp_mem_MB); |
| 102 | + return ans; |
| 103 | +} |
| 104 | + |
| 105 | +static auto registry = torch::RegisterOperators( |
| 106 | + "open3d::continuous_conv(Tensor filters, Tensor out_positions, Tensor " |
| 107 | + "extents, Tensor offset, Tensor inp_positions, Tensor inp_features, " |
| 108 | + "Tensor inp_importance, Tensor neighbors_index, Tensor " |
| 109 | + "neighbors_importance, Tensor neighbors_row_splits, bool " |
| 110 | + "align_corners=False, str coordinate_mapping=\"ball_to_cube_radial\", " |
| 111 | + "bool normalize=False, str interpolation=\"linear\", int " |
| 112 | + "max_temp_mem_MB=64) -> Tensor", |
| 113 | + &::ContinuousConv); |
| 114 | +``` |
| 115 | +
|
| 116 | +2. 对于只有前向的算子(e.g BuildSpatialHashTable),只注册前向算子 |
| 117 | +```cpp |
| 118 | +#include <vector> |
| 119 | +
|
| 120 | +#include "open3d/ml/pytorch/TorchHelper.h" |
| 121 | +#include "torch/script.h" |
| 122 | +
|
| 123 | +template <class T> |
| 124 | +void BuildSpatialHashTableCPU(const torch::Tensor& points, |
| 125 | + double radius, |
| 126 | + const torch::Tensor& points_row_splits, |
| 127 | + const std::vector<uint32_t>& hash_table_splits, |
| 128 | + torch::Tensor& hash_table_index, |
| 129 | + torch::Tensor& hash_table_cell_splits); |
| 130 | +#ifdef BUILD_CUDA_MODULE |
| 131 | +template <class T> |
| 132 | +void BuildSpatialHashTableCUDA(const torch::Tensor& points, |
| 133 | + double radius, |
| 134 | + const torch::Tensor& points_row_splits, |
| 135 | + const std::vector<uint32_t>& hash_table_splits, |
| 136 | + torch::Tensor& hash_table_index, |
| 137 | + torch::Tensor& hash_table_cell_splits); |
| 138 | +#endif |
| 139 | +
|
| 140 | +std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> BuildSpatialHashTable( |
| 141 | + torch::Tensor points, |
| 142 | + double radius, |
| 143 | + torch::Tensor points_row_splits, |
| 144 | + double hash_table_size_factor, |
| 145 | + int64_t max_hash_table_size) { |
| 146 | + ... |
| 147 | +} |
| 148 | +
|
| 149 | +static auto registry = torch::RegisterOperators( |
| 150 | + "open3d::build_spatial_hash_table(Tensor points, float radius, Tensor " |
| 151 | + "points_row_splits, float hash_table_size_factor, int " |
| 152 | + "max_hash_table_size=33554432) -> (Tensor hash_table_index, Tensor " |
| 153 | + "hash_table_cell_splits, Tensor hash_table_splits)", |
| 154 | + &BuildSpatialHashTable); |
| 155 | +``` |
| 156 | + |
| 157 | +Paddle自定义算子与torch存在一些差别,Paddle支持直接将func注册为对应op的前向和反向算子。 |
| 158 | +```cpp |
| 159 | +// forward |
| 160 | +std::vector<paddle::Tensor> Forward( |
| 161 | + const paddle::Tensor& x, |
| 162 | + bool bool_attr, |
| 163 | + int int_attr, |
| 164 | + float float_attr, |
| 165 | + int64_t int64_attr, |
| 166 | + const std::string& str_attr, |
| 167 | + const std::vector<int>& int_vec_attr, |
| 168 | + const std::vector<float>& float_vec_attr, |
| 169 | + const std::vector<int64_t>& int64_vec_attr, |
| 170 | + const std::vector<std::string>& str_vec_attr) {...} |
| 171 | + |
| 172 | +// forward函数注册 |
| 173 | +PD_BUILD_OP(op_name) |
| 174 | + .Inputs({"X"}) |
| 175 | + .Outputs({"Out"}) |
| 176 | + .Attrs({"bool_attr: bool", |
| 177 | + "int_attr: int", |
| 178 | + "float_attr: float", |
| 179 | + "int64_attr: int64_t", |
| 180 | + "str_attr: std::string", |
| 181 | + "int_vec_attr: std::vector<int>", |
| 182 | + "float_vec_attr: std::vector<float>", |
| 183 | + "int64_vec_attr: std::vector<int64_t>", |
| 184 | + "str_vec_attr: std::vector<std::string>"}) |
| 185 | + .SetKernelFn(PD_KERNEL(Forward)); |
| 186 | + |
| 187 | +// backward |
| 188 | +std::vector<paddle::Tensor> Backward( |
| 189 | + const paddle::Tensor& grad_out, |
| 190 | + int int_attr, |
| 191 | + const std::vector<float>& float_vec_attr, |
| 192 | + const std::vector<std::string>& str_vec_attr) {...} |
| 193 | + |
| 194 | +PD_BUILD_GRAD_OP(op_name) |
| 195 | + .Inputs({paddle::Grad("Out")}) |
| 196 | + .Outputs({paddle::Grad("X")}) |
| 197 | + .Attrs({"int_attr: int", |
| 198 | + "float_vec_attr: std::vector<float>", |
| 199 | + "str_vec_attr: std::vector<std::string>"}) |
| 200 | + .SetKernelFn(PD_KERNEL(Backward)); |
| 201 | +``` |
| 202 | +
|
| 203 | +## 4. 设计思路与实现方案 |
| 204 | +
|
| 205 | +### 4.1 Paddle后端集成方案 |
| 206 | +
|
| 207 | +1. 采用 Paddle 的自定义算子接口进行C++算子注册,参考CppExtension在Open3D中实现自定义算子python接口codegen对应逻辑。 |
| 208 | +2. 不用CppExtension/CudaExtension组织编译,在Open3D中采用cmake组织自定义算子编译。 |
| 209 | +
|
| 210 | +### 4.2 算子及对应单测梳理 |
| 211 | +
|
| 212 | +算子路径: |
| 213 | +1. cpp/open3d/ml/pytorch/* |
| 214 | +
|
| 215 | +Pytorch Module路径: |
| 216 | +1. python/open3d/ml/torch/classes/ragged_tensor.py |
| 217 | +2. python/open3d/ml/torch/python/layers* |
| 218 | +
|
| 219 | +单测路径: |
| 220 | +1. python/test/ml_ops, |
| 221 | +
|
| 222 | +| 算子名称 | 需要反向 | 分类 | 对应 pytorch module | 对应单测 | 优先级 | |
| 223 | +| ----------------------------- | -------- | --------------- | ------------------------------------------------------------------ | ----------------------------------------------------------------------- | ----------------------- | |
| 224 | +| continuous_conv | Yes | continuous_conv | ContinuousConv, SparseConv, SparseConvTranspose | test_cconv.py | P0 | |
| 225 | +| build_spatial_hash_table | No | misc | FixedRadiusSearch | 无 | P0 | |
| 226 | +| continuous_conv_transpose | Yes | continuous_conv | SparseConvTranspose | test_cconv.py | P0 | |
| 227 | +| fixed_radius_search | No | misc | ContinuousConv, SparseConv, SparseConvTranspose, FixedRadiusSearch | test_fixed_radius_search.py | P0 | |
| 228 | +| invert_neighbors_list | No | misc | SparseConvTranspose | test_general_sparseconv.py, test_cconv.py | P0 | |
| 229 | +| knn_search | No | misc | KNNSearch | test_knn_search.py | P0 | |
| 230 | +| radius_search | No | misc | RadiusSearch | test_radius_search.py | P0 | |
| 231 | +| voxel_pooling | Yes | misc | VoxelPooling | test_voxel_pooling.py | P0 | |
| 232 | +| ragged_to_dense | No | misc | | RaggedTensor | test_ragged_to_dense.py | P0 | |
| 233 | +| nms | No | misc | 无 | test_nms.py | P1 | |
| 234 | +| reduce_subarrays_sum | No | misc | 无 | test_cconv.py, test_general_sparseconv.py, test_reduce_subarrays_sum.py | P1 | |
| 235 | +| voxelize | No | misc | 无 | test_voxelize.py | P1 | |
| 236 | +| roi_pool | No | misc | 无 | test_roi_pool.py | P1 | |
| 237 | +| ball_query | No | pointnet | 无 | test_query_pts.py | P1 | |
| 238 | +| three_nn | No | pointnet | 无 | test_three_nn.py | P1 | |
| 239 | +| three_interpolate | No | pointnet | 无 | test_three_interp.py | P1 | |
| 240 | +| furthest_point_sampling | No | pointnet | 无 | test_sampling.py | P1 | |
| 241 | +| sparse_conv | Yes | sparse_conv | 无 | test_general_sparseconv.py, test_sparseconv.py | P1 | |
| 242 | +| sparse_conv_transpose | Yes | sparse_conv | 无 | test_general_sparseconv.py,test_sparseconv.py | P1 | |
| 243 | +| three_interpolate_grad | No | pointnet | 无 | 无 | P2 | |
| 244 | +| trilinear_devoxelize_forward | No | pvcnn | 无 | 无 | P2 | |
| 245 | +| trilinear_devoxelize_backward | No | pvcnn | 无 | 无 | P2 | |
| 246 | +
|
| 247 | +> NOTE: P0 和 P1 优先级算子必须实现,P2优先级算子的选择实现 |
| 248 | +
|
| 249 | +### 4.3 方案验证 |
| 250 | +
|
| 251 | +基于上述方案,已初步完成验证,测试代码及结果如下: |
| 252 | +
|
| 253 | +```python |
| 254 | +import open3d.ml.paddle as p_ml3d |
| 255 | +import paddle |
| 256 | +import open3d.ml.torch as t_ml3d |
| 257 | +import torch |
| 258 | +import numpy as np |
| 259 | +
|
| 260 | +class p_Search(paddle.nn.Layer): |
| 261 | + def __init__(self, r: float): |
| 262 | + super().__init__() |
| 263 | + self.r = r |
| 264 | + self.search = p_ml3d.layers.FixedRadiusSearch(return_distances=True) |
| 265 | +
|
| 266 | + def forward(self, x, y): |
| 267 | + nei = self.search(x, y, self.r) |
| 268 | + return nei |
| 269 | +
|
| 270 | +class t_Search(torch.nn.Module): |
| 271 | + def __init__(self, r: float): |
| 272 | + super().__init__() |
| 273 | + self.r = r |
| 274 | + self.search = t_ml3d.layers.FixedRadiusSearch(return_distances=True) |
| 275 | +
|
| 276 | + def forward(self, x, y): |
| 277 | + nei = self.search(x, y, self.r) |
| 278 | + return nei |
| 279 | +
|
| 280 | +net_p = p_Search(r=0.8) |
| 281 | +net_t = t_Search(r=0.8) |
| 282 | +
|
| 283 | +x= np.random.randn(20,3) |
| 284 | +y=np.random.randn(10,3) |
| 285 | +
|
| 286 | +xp = paddle.to_tensor(x,dtype='float32').cpu() |
| 287 | +yp = paddle.to_tensor(y,dtype='float32').cpu() |
| 288 | +
|
| 289 | +
|
| 290 | +xt = torch.from_numpy(x).float().cpu() |
| 291 | +yt = torch.from_numpy(y).float().cpu() |
| 292 | +
|
| 293 | +out_p = net_p(xp,yp) |
| 294 | +out_t = net_t(xt,yt) |
| 295 | +
|
| 296 | +print(out_p) |
| 297 | +print(out_t) |
| 298 | +``` |
| 299 | + |
| 300 | +运行结果: |
| 301 | + |
| 302 | + |
| 303 | + |
| 304 | + |
| 305 | +## 5. 测试和验收的考量 |
| 306 | + |
| 307 | +1. P0 和 P1 级别优先级的算子对应单测 100% 通过 |
| 308 | +2. Open3D 现有 CI 100%通过 |
| 309 | +3. [可选] 为 Open3D 新增 Paddle CI 100% 通过 |
| 310 | + |
| 311 | +## 6. 可行性分析和排期规划 |
| 312 | + |
| 313 | +整体来看,项目可行,具体排期规划如下: |
| 314 | + |
| 315 | +| 里程碑 | 人力(人日) | 时间点 | |
| 316 | +| ------------------------------------------------------------------ | ------------ | ---------- | |
| 317 | +| 编译 Open3D,跑通 pytorch example | 1 | | |
| 318 | +| 参考 pytorch backend ,梳理需要改写的C++算子和对应 paddle module。 | 0.5 | | |
| 319 | +| 完成 Open3D + paddle 编译 + 1 * demo | 3 | | |
| 320 | +| 提交RFC | 1 | 2024/10/18 | |
| 321 | +| 为 paddle 实现 C++ 算子和对应 module,编写并通过相关单测验证 | 20 | | |
| 322 | +| 调通 Open3D paddle CI | 5 | | |
| 323 | +| 提交PR,完成合入 | 5 | 2024/11/25 | |
| 324 | +| [可选] Open3D-ML 适配 | 20 | long term | |
| 325 | + |
| 326 | +## 7. 影响面 |
| 327 | + |
| 328 | +### 7.1 PaddleScience |
| 329 | +代码合入 Open3D 仓库,对 PaddleScience 无影响。 |
| 330 | +### 7.2 Open3D |
| 331 | +本次适配为多后端支持,不影响 Open3D 现有功能,需要为Open3D增加对应paddle backend 的 CI。 |
0 commit comments