Skip to content

Commit fb88dea

Browse files
committed
[RFC] Design document of Open3D for paddle backend
1 parent 6c37103 commit fb88dea

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
# 飞桨适配 Open3D
2+
3+
> RFC 文档相关记录信息
4+
5+
| | |
6+
| ------------ | -------------------- |
7+
| 提交作者 | SecretXV |
8+
| 提交时间 | 2024-10-18 |
9+
| RFC 版本号 | v1.0 |
10+
| 依赖飞桨版本 | develop |
11+
| 文件名 | 20241018_open3d_for_paddle.md |
12+
13+
## 1. 概述
14+
15+
### 1.1 相关背景
16+
17+
Open3D是一个开源库,专门用于处理三维数据的计算机视觉和图形学任务。它由英伟达(NVIDIA)资助,并且广泛应用于学术研究和工业界。 Open3D提供了丰富的工具集,用于 3D数据处理、可视化、机器学习、和 几何运算,使得用户可以高效地处理点云、网格、体素等三维数据。
18+
19+
### 1.2 功能目标
20+
21+
1. 完成 Open3D 和 Paddle 联合编译
22+
2. 完成 PyTorch 算子向 Paddle 迁移,原则上所有算子都需要迁移
23+
3. 完成 Paddle Layer 封装,结构与 PyTorch 保持一致
24+
4. 完成相关代码合入 Open3D 仓库
25+
26+
### 1.3 意义
27+
28+
Open3D 集成了机器学习的功能,简化 3D 数据的机器学习任务,例如点云分类、分割和物体检测。它提供了易于使用的 API 和工具集,能够与主流深度学习框架(如 PyTorch 和 TensorFlow)无缝集成。通过为 Open3D 适配 paddle 后端,可以将 PyTorch 上的代码快速迁移至 Paddle,丰富 Paddle 在点云分类、分割和物体检测等场景的能力。
29+
30+
## 2. Open3D 现状
31+
32+
Open3D中实现了机器学习相关的模块,已经支持 PyTorch 和 Tensorflow 后端,具体包括:相关算子支持,dataset,loss,pipeline,visualize功能支持。其中,dataset,loss,pipeline,visualize功能具体在Open3D-ML中实现。
33+
34+
> NOTE:本次适配计划 Open3D-ML 非强制要求
35+
36+
## 3. 目标调研
37+
38+
### 3.1 后端集成方案调研
39+
不用CppExtension/CudaExtension,在Open3D中采用CMake组织自定义算子编译
40+
1. 优点:与pytorch和tf的自定义算子编译方式保持统一
41+
2. 缺点:需要自行组织算子编译;需要将paddle的cpp_extension在Open3D中进行实现。
42+
43+
采用CppExtension/CudaExtension进行自定义算子编译
44+
1. 优点:实现便捷
45+
2. 缺点:采用CppExtension/CudaExtension编译的自定义算子会生成一个独立的模块,需要考虑如何自动化的打包到whl包中;算子存在第三方依赖,需要实现 CMake 和 setup.py 的参数透传
46+
### 3.2 算子注册方案调研
47+
Open3D中torch算子根据是否有反向算子,存在两种注册方式:
48+
1. 对于存在反向的算子(e.g ContinuousConv),会创建一个继承自torch.autograd.Function的类,并实现forward和backward方法。本质是为了让autograd根据forward函数查找到对应的backward函数
49+
50+
```cpp
51+
using namespace open3d::ml::impl;
52+
using torch::autograd::AutogradContext;
53+
using torch::autograd::Function;
54+
using torch::autograd::Variable;
55+
using torch::autograd::variable_list;
56+
57+
class ContinuousConvFunction : public Function<ContinuousConvFunction> {
58+
public:
59+
static Variable forward(AutogradContext* ctx,
60+
Variable filters,
61+
Variable out_positions,
62+
Variable extents,
63+
Variable offset,
64+
Variable inp_positions,
65+
Variable inp_features,
66+
Variable inp_importance,
67+
Variable neighbors_index,
68+
Variable neighbors_importance,
69+
Variable neighbors_row_splits,
70+
const bool align_corners,
71+
const std::string& coordinate_mapping_str,
72+
const bool normalize,
73+
const std::string& interpolation_str,
74+
const int64_t max_temp_mem_MB) {
75+
...
76+
}
77+
78+
static variable_list backward(AutogradContext* ctx,
79+
variable_list grad_output) {
80+
...
81+
};
82+
torch::Tensor ContinuousConv(const torch::Tensor& filters,
83+
const torch::Tensor& out_positions,
84+
const torch::Tensor& extents,
85+
const torch::Tensor& offset,
86+
const torch::Tensor& inp_positions,
87+
const torch::Tensor& inp_features,
88+
const torch::Tensor& inp_importance,
89+
const torch::Tensor& neighbors_index,
90+
const torch::Tensor& neighbors_importance,
91+
const torch::Tensor& neighbors_row_splits,
92+
const bool align_corners,
93+
const std::string& coordinate_mapping_str,
94+
const bool normalize,
95+
const std::string& interpolation_str,
96+
const int64_t max_temp_mem_MB) {
97+
auto ans = ContinuousConvFunction::apply(
98+
filters, out_positions, extents, offset, inp_positions,
99+
inp_features, inp_importance, neighbors_index, neighbors_importance,
100+
neighbors_row_splits, align_corners, coordinate_mapping_str,
101+
normalize, interpolation_str, max_temp_mem_MB);
102+
return ans;
103+
}
104+
105+
static auto registry = torch::RegisterOperators(
106+
"open3d::continuous_conv(Tensor filters, Tensor out_positions, Tensor "
107+
"extents, Tensor offset, Tensor inp_positions, Tensor inp_features, "
108+
"Tensor inp_importance, Tensor neighbors_index, Tensor "
109+
"neighbors_importance, Tensor neighbors_row_splits, bool "
110+
"align_corners=False, str coordinate_mapping=\"ball_to_cube_radial\", "
111+
"bool normalize=False, str interpolation=\"linear\", int "
112+
"max_temp_mem_MB=64) -> Tensor",
113+
&::ContinuousConv);
114+
```
115+
116+
2. 对于只有前向的算子(e.g BuildSpatialHashTable),只注册前向算子
117+
```cpp
118+
#include <vector>
119+
120+
#include "open3d/ml/pytorch/TorchHelper.h"
121+
#include "torch/script.h"
122+
123+
template <class T>
124+
void BuildSpatialHashTableCPU(const torch::Tensor& points,
125+
double radius,
126+
const torch::Tensor& points_row_splits,
127+
const std::vector<uint32_t>& hash_table_splits,
128+
torch::Tensor& hash_table_index,
129+
torch::Tensor& hash_table_cell_splits);
130+
#ifdef BUILD_CUDA_MODULE
131+
template <class T>
132+
void BuildSpatialHashTableCUDA(const torch::Tensor& points,
133+
double radius,
134+
const torch::Tensor& points_row_splits,
135+
const std::vector<uint32_t>& hash_table_splits,
136+
torch::Tensor& hash_table_index,
137+
torch::Tensor& hash_table_cell_splits);
138+
#endif
139+
140+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> BuildSpatialHashTable(
141+
torch::Tensor points,
142+
double radius,
143+
torch::Tensor points_row_splits,
144+
double hash_table_size_factor,
145+
int64_t max_hash_table_size) {
146+
...
147+
}
148+
149+
static auto registry = torch::RegisterOperators(
150+
"open3d::build_spatial_hash_table(Tensor points, float radius, Tensor "
151+
"points_row_splits, float hash_table_size_factor, int "
152+
"max_hash_table_size=33554432) -> (Tensor hash_table_index, Tensor "
153+
"hash_table_cell_splits, Tensor hash_table_splits)",
154+
&BuildSpatialHashTable);
155+
```
156+
157+
Paddle自定义算子与torch存在一些差别,Paddle支持直接将func注册为对应op的前向和反向算子。
158+
```cpp
159+
// forward
160+
std::vector<paddle::Tensor> Forward(
161+
const paddle::Tensor& x,
162+
bool bool_attr,
163+
int int_attr,
164+
float float_attr,
165+
int64_t int64_attr,
166+
const std::string& str_attr,
167+
const std::vector<int>& int_vec_attr,
168+
const std::vector<float>& float_vec_attr,
169+
const std::vector<int64_t>& int64_vec_attr,
170+
const std::vector<std::string>& str_vec_attr) {...}
171+
172+
// forward函数注册
173+
PD_BUILD_OP(op_name)
174+
.Inputs({"X"})
175+
.Outputs({"Out"})
176+
.Attrs({"bool_attr: bool",
177+
"int_attr: int",
178+
"float_attr: float",
179+
"int64_attr: int64_t",
180+
"str_attr: std::string",
181+
"int_vec_attr: std::vector<int>",
182+
"float_vec_attr: std::vector<float>",
183+
"int64_vec_attr: std::vector<int64_t>",
184+
"str_vec_attr: std::vector<std::string>"})
185+
.SetKernelFn(PD_KERNEL(Forward));
186+
187+
// backward
188+
std::vector<paddle::Tensor> Backward(
189+
const paddle::Tensor& grad_out,
190+
int int_attr,
191+
const std::vector<float>& float_vec_attr,
192+
const std::vector<std::string>& str_vec_attr) {...}
193+
194+
PD_BUILD_GRAD_OP(op_name)
195+
.Inputs({paddle::Grad("Out")})
196+
.Outputs({paddle::Grad("X")})
197+
.Attrs({"int_attr: int",
198+
"float_vec_attr: std::vector<float>",
199+
"str_vec_attr: std::vector<std::string>"})
200+
.SetKernelFn(PD_KERNEL(Backward));
201+
```
202+
203+
## 4. 设计思路与实现方案
204+
205+
### 4.1 Paddle后端集成方案
206+
207+
1. 采用 Paddle 的自定义算子接口进行C++算子注册,参考CppExtension在Open3D中实现自定义算子python接口codegen对应逻辑。
208+
2. 不用CppExtension/CudaExtension组织编译,在Open3D中采用cmake组织自定义算子编译。
209+
210+
### 4.2 算子及对应单测梳理
211+
212+
算子路径:
213+
1. cpp/open3d/ml/pytorch/*
214+
215+
Pytorch Module路径:
216+
1. python/open3d/ml/torch/classes/ragged_tensor.py
217+
2. python/open3d/ml/torch/python/layers*
218+
219+
单测路径:
220+
1. python/test/ml_ops,
221+
222+
| 算子名称 | 需要反向 | 分类 | 对应 pytorch module | 对应单测 | 优先级 |
223+
| ----------------------------- | -------- | --------------- | ------------------------------------------------------------------ | ----------------------------------------------------------------------- | ----------------------- |
224+
| continuous_conv | Yes | continuous_conv | ContinuousConv, SparseConv, SparseConvTranspose | test_cconv.py | P0 |
225+
| build_spatial_hash_table | No | misc | FixedRadiusSearch | 无 | P0 |
226+
| continuous_conv_transpose | Yes | continuous_conv | SparseConvTranspose | test_cconv.py | P0 |
227+
| fixed_radius_search | No | misc | ContinuousConv, SparseConv, SparseConvTranspose, FixedRadiusSearch | test_fixed_radius_search.py | P0 |
228+
| invert_neighbors_list | No | misc | SparseConvTranspose | test_general_sparseconv.py, test_cconv.py | P0 |
229+
| knn_search | No | misc | KNNSearch | test_knn_search.py | P0 |
230+
| radius_search | No | misc | RadiusSearch | test_radius_search.py | P0 |
231+
| voxel_pooling | Yes | misc | VoxelPooling | test_voxel_pooling.py | P0 |
232+
| ragged_to_dense | No | misc | | RaggedTensor | test_ragged_to_dense.py | P0 |
233+
| nms | No | misc | 无 | test_nms.py | P1 |
234+
| reduce_subarrays_sum | No | misc | 无 | test_cconv.py, test_general_sparseconv.py, test_reduce_subarrays_sum.py | P1 |
235+
| voxelize | No | misc | 无 | test_voxelize.py | P1 |
236+
| roi_pool | No | misc | 无 | test_roi_pool.py | P1 |
237+
| ball_query | No | pointnet | 无 | test_query_pts.py | P1 |
238+
| three_nn | No | pointnet | 无 | test_three_nn.py | P1 |
239+
| three_interpolate | No | pointnet | 无 | test_three_interp.py | P1 |
240+
| furthest_point_sampling | No | pointnet | 无 | test_sampling.py | P1 |
241+
| sparse_conv | Yes | sparse_conv | 无 | test_general_sparseconv.py, test_sparseconv.py | P1 |
242+
| sparse_conv_transpose | Yes | sparse_conv | 无 | test_general_sparseconv.py,test_sparseconv.py | P1 |
243+
| three_interpolate_grad | No | pointnet | 无 | 无 | P2 |
244+
| trilinear_devoxelize_forward | No | pvcnn | 无 | 无 | P2 |
245+
| trilinear_devoxelize_backward | No | pvcnn | 无 | 无 | P2 |
246+
247+
> NOTE: P0 和 P1 优先级算子必须实现,P2优先级算子的选择实现
248+
249+
### 4.3 方案验证
250+
251+
基于上述方案,已初步完成验证,测试代码及结果如下:
252+
253+
```python
254+
import open3d.ml.paddle as p_ml3d
255+
import paddle
256+
import open3d.ml.torch as t_ml3d
257+
import torch
258+
import numpy as np
259+
260+
class p_Search(paddle.nn.Layer):
261+
def __init__(self, r: float):
262+
super().__init__()
263+
self.r = r
264+
self.search = p_ml3d.layers.FixedRadiusSearch(return_distances=True)
265+
266+
def forward(self, x, y):
267+
nei = self.search(x, y, self.r)
268+
return nei
269+
270+
class t_Search(torch.nn.Module):
271+
def __init__(self, r: float):
272+
super().__init__()
273+
self.r = r
274+
self.search = t_ml3d.layers.FixedRadiusSearch(return_distances=True)
275+
276+
def forward(self, x, y):
277+
nei = self.search(x, y, self.r)
278+
return nei
279+
280+
net_p = p_Search(r=0.8)
281+
net_t = t_Search(r=0.8)
282+
283+
x= np.random.randn(20,3)
284+
y=np.random.randn(10,3)
285+
286+
xp = paddle.to_tensor(x,dtype='float32').cpu()
287+
yp = paddle.to_tensor(y,dtype='float32').cpu()
288+
289+
290+
xt = torch.from_numpy(x).float().cpu()
291+
yt = torch.from_numpy(y).float().cpu()
292+
293+
out_p = net_p(xp,yp)
294+
out_t = net_t(xt,yt)
295+
296+
print(out_p)
297+
print(out_t)
298+
```
299+
300+
运行结果:
301+
![image](./images/open3d_for_paddle_demo_result.png)
302+
303+
304+
305+
## 5. 测试和验收的考量
306+
307+
1. P0 和 P1 级别优先级的算子对应单测 100% 通过
308+
2. Open3D 现有 CI 100%通过
309+
3. [可选] 为 Open3D 新增 Paddle CI 100% 通过
310+
311+
## 6. 可行性分析和排期规划
312+
313+
整体来看,项目可行,具体排期规划如下:
314+
315+
| 里程碑 | 人力(人日) | 时间点 |
316+
| ------------------------------------------------------------------ | ------------ | ---------- |
317+
| 编译 Open3D,跑通 pytorch example | 1 | |
318+
| 参考 pytorch backend ,梳理需要改写的C++算子和对应 paddle module。 | 0.5 | |
319+
| 完成 Open3D + paddle 编译 + 1 * demo | 3 | |
320+
| 提交RFC | 1 | 2024/10/18 |
321+
| 为 paddle 实现 C++ 算子和对应 module,编写并通过相关单测验证 | 20 | |
322+
| 调通 Open3D paddle CI | 5 | |
323+
| 提交PR,完成合入 | 5 | 2024/11/25 |
324+
| [可选] Open3D-ML 适配 | 20 | long term |
325+
326+
## 7. 影响面
327+
328+
### 7.1 PaddleScience
329+
代码合入 Open3D 仓库,对 PaddleScience 无影响。
330+
### 7.2 Open3D
331+
本次适配为多后端支持,不影响 Open3D 现有功能,需要为Open3D增加对应paddle backend 的 CI。
278 KB
Loading

0 commit comments

Comments
 (0)