
Commit 1563fb1

Initial commit
0 parents  commit 1563fb1

File tree

7 files changed: +495 -0 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
.DS_Store
._*

LICENSE.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Chamfer Distance for PyTorch

This is an implementation of the Chamfer Distance as a module for PyTorch, written as a custom C++/CUDA extension.

Because it uses PyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), no additional build steps are required: simply import the module as shown below, and the C++ and CUDA code is compiled on the first run.

### Usage
```python
import torch
from chamfer_distance import ChamferDistance
chamfer_dist = ChamferDistance()

# ...
# points and points_reconstructed are n_points x 3 matrices

dist1, dist2 = chamfer_dist(points, points_reconstructed)
loss = torch.mean(dist1) + torch.mean(dist2)
# ...
```
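
The two returned tensors are squared nearest-neighbor distances in each direction; the CPU path in chamfer_distance.cpp below accumulates x*x + y*y + z*z and never takes a square root. As a sanity check on small inputs, here is a minimal pure-PyTorch sketch of the same two terms (the batched shapes and the torch.cdist call are illustrative assumptions, not part of this commit):

```python
import torch

def chamfer_reference(xyz1, xyz2):
    # xyz1: (b, n, 3), xyz2: (b, m, 3)
    d = torch.cdist(xyz1, xyz2) ** 2     # pairwise squared distances, (b, n, m)
    dist1 = d.min(dim=2).values          # nearest point in xyz2 for each point in xyz1
    dist2 = d.min(dim=1).values          # nearest point in xyz1 for each point in xyz2
    return dist1, dist2

points = torch.rand(2, 128, 3)
points_reconstructed = torch.rand(2, 128, 3)
dist1, dist2 = chamfer_reference(points, points_reconstructed)
loss = torch.mean(dist1) + torch.mean(dist2)
```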

chamfer_distance/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .chamfer_distance import ChamferDistance
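
The chamfer_distance.py module imported here is presumably among the remaining files in this commit, but it is not shown in this view. Judging from the README's mention of JIT compilation and the forward/forward_cuda/backward/backward_cuda bindings exported by chamfer_distance.cpp below, a wrapper along the following lines would be typical; this is a hedged sketch, and every path, name, and helper in it is an assumption rather than the file's actual contents:

```python
import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build of the extension; the module name and source paths are assumptions.
cd = load(name="cd",
          sources=["chamfer_distance/chamfer_distance.cpp",
                   "chamfer_distance/chamfer_distance.cu"])

class ChamferDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        b, n, m = xyz1.size(0), xyz1.size(1), xyz2.size(1)
        dist1 = torch.zeros(b, n, device=xyz1.device)
        dist2 = torch.zeros(b, m, device=xyz1.device)
        idx1 = torch.zeros(b, n, dtype=torch.int32, device=xyz1.device)
        idx2 = torch.zeros(b, m, dtype=torch.int32, device=xyz1.device)
        # dispatch to the CUDA or CPU binding exported by chamfer_distance.cpp
        fn = cd.forward_cuda if xyz1.is_cuda else cd.forward
        fn(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2

    @staticmethod
    def backward(ctx, graddist1, graddist2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        gradxyz1 = torch.zeros_like(xyz1)
        gradxyz2 = torch.zeros_like(xyz2)
        fn = cd.backward_cuda if xyz1.is_cuda else cd.backward
        fn(xyz1, xyz2, gradxyz1, gradxyz2,
           graddist1.contiguous(), graddist2.contiguous(), idx1, idx2)
        return gradxyz1, gradxyz2

class ChamferDistance(torch.nn.Module):
    def forward(self, xyz1, xyz2):
        return ChamferDistanceFunction.apply(xyz1, xyz2)
```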

chamfer_distance/chamfer_distance.cpp

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
#include <torch/torch.h>

// CUDA forward declarations
int ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i);

int ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    const float* grad_dist2,
    const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2);


void chamfer_distance_forward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                  xyz2.size(1), xyz2.data<float>(),
                                  dist1.data<float>(), idx1.data<int>(),
                                  dist2.data<float>(), idx2.data<int>());
}

void chamfer_distance_backward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                      xyz2.size(1), xyz2.data<float>(),
                                      graddist1.data<float>(), idx1.data<int>(),
                                      graddist2.data<float>(), idx2.data<int>(),
                                      gradxyz1.data<float>(), gradxyz2.data<float>());
}


// Brute-force nearest-neighbor search on the CPU: for every point in xyz1,
// store the squared distance to (and the index of) its closest point in xyz2.
void nnsearch(
    const int b, const int n, const int m,
    const float* xyz1,
    const float* xyz2,
    float* dist,
    int* idx)
{
    for (int i = 0; i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1[(i*n+j)*3+0];
            const float y1 = xyz1[(i*n+j)*3+1];
            const float z1 = xyz1[(i*n+j)*3+2];
            double best = 0;
            int besti = 0;
            for (int k = 0; k < m; k++) {
                const float x2 = xyz2[(i*m+k)*3+0] - x1;
                const float y2 = xyz2[(i*m+k)*3+1] - y1;
                const float z2 = xyz2[(i*m+k)*3+2] - z1;
                const double d = x2*x2 + y2*y2 + z2*z2;
                if (k == 0 || d < best) {
                    best = d;
                    besti = k;
                }
            }
            dist[i*n+j] = best;
            idx[i*n+j] = besti;
        }
    }
}


// CPU forward: nearest-neighbor search in both directions.
void chamfer_distance_forward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    const int batchsize = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* dist1_data = dist1.data<float>();
    float* dist2_data = dist2.data<float>();
    int* idx1_data = idx1.data<int>();
    int* idx2_data = idx2.data<int>();

    nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
    nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
}


// CPU backward: each squared distance contributes a gradient of 2*(p1 - p2),
// scattered to the two matched points with opposite signs.
void chamfer_distance_backward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    const int b = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* gradxyz1_data = gradxyz1.data<float>();
    float* gradxyz2_data = gradxyz2.data<float>();
    float* graddist1_data = graddist1.data<float>();
    float* graddist2_data = graddist2.data<float>();
    const int* idx1_data = idx1.data<int>();
    const int* idx2_data = idx2.data<int>();

    for (int i = 0; i < b*n*3; i++)
        gradxyz1_data[i] = 0;
    for (int i = 0; i < b*m*3; i++)
        gradxyz2_data[i] = 0;

    for (int i = 0; i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1_data[(i*n+j)*3+0];
            const float y1 = xyz1_data[(i*n+j)*3+1];
            const float z1 = xyz1_data[(i*n+j)*3+2];
            const int j2 = idx1_data[i*n+j];

            const float x2 = xyz2_data[(i*m+j2)*3+0];
            const float y2 = xyz2_data[(i*m+j2)*3+1];
            const float z2 = xyz2_data[(i*m+j2)*3+2];
            const float g = graddist1_data[i*n+j]*2;

            gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
            gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
            gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
            gradxyz2_data[(i*m+j2)*3+0] -= g*(x1-x2);
            gradxyz2_data[(i*m+j2)*3+1] -= g*(y1-y2);
            gradxyz2_data[(i*m+j2)*3+2] -= g*(z1-z2);
        }
        for (int j = 0; j < m; j++) {
            const float x1 = xyz2_data[(i*m+j)*3+0];
            const float y1 = xyz2_data[(i*m+j)*3+1];
            const float z1 = xyz2_data[(i*m+j)*3+2];
            const int j2 = idx2_data[i*m+j];
            const float x2 = xyz1_data[(i*n+j2)*3+0];
            const float y2 = xyz1_data[(i*n+j2)*3+1];
            const float z2 = xyz1_data[(i*n+j2)*3+2];
            const float g = graddist2_data[i*m+j]*2;
            gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
            gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
            gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
            gradxyz1_data[(i*n+j2)*3+0] -= g*(x1-x2);
            gradxyz1_data[(i*n+j2)*3+1] -= g*(y1-y2);
            gradxyz1_data[(i*n+j2)*3+2] -= g*(z1-z2);
        }
    }
}


// Python bindings: expose both the CPU and the CUDA entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
    m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
    m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
    m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
}
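
The backward functions above implement the gradient of the squared nearest-neighbor distance with the matched indices held fixed: for d = ||p1 - p2||^2, the gradient with respect to p1 is 2*(p1 - p2) and with respect to p2 is -2*(p1 - p2), which is exactly the g*(x1 - x2) scatter in the loops. A standalone numeric check of that formula with autograd (no extension build required; the tensors here are purely illustrative):

```python
import torch

p1 = torch.rand(3, requires_grad=True)
p2 = torch.rand(3, requires_grad=True)

d = ((p1 - p2) ** 2).sum()   # squared distance, as computed by nnsearch and the kernels
d.backward()

print(torch.allclose(p1.grad, 2 * (p1 - p2).detach()))   # True
print(torch.allclose(p2.grad, -2 * (p1 - p2).detach()))  # True
```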
