Skip to content

Commit 74659ae

Browse files
bottlerfacebook-github-bot
authored and committed
CPU implementation for point_mesh functions
Summary: point_mesh functions were missing CPU implementations. The indices returned are not always matching, possibly due to numerical instability. Reviewed By: gkioxari Differential Revision: D21594264 fbshipit-source-id: 3016930e2a9a0f3cd8b3ac4c94a92c9411c0989d
1 parent 7f1e63a commit 74659ae

File tree

6 files changed

+878
-31
lines changed

6 files changed

+878
-31
lines changed
+398
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,398 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#include <torch/extension.h>
4+
#include <array>
5+
#include <limits>
6+
#include "utils/geometry_utils.h"
7+
#include "utils/vec3.h"
8+
9+
// - We start with implementations of simple operations on points, edges and
10+
// faces. The hull of H points is a point if H=1, an edge if H=2, a face if H=3.
11+
12+
template <typename T>
13+
vec3<T> ExtractPoint(const at::TensorAccessor<T, 1>& t) {
14+
return vec3(t[0], t[1], t[2]);
15+
}
16+
17+
template <class Accessor>
18+
struct ExtractHullHelper {
19+
template <int H>
20+
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H>
21+
get(const Accessor& t);
22+
23+
template <>
24+
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 1>
25+
get<1>(const Accessor& t) {
26+
return {ExtractPoint(t)};
27+
}
28+
29+
template <>
30+
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 2>
31+
get<2>(const Accessor& t) {
32+
return {ExtractPoint(t[0]), ExtractPoint(t[1])};
33+
}
34+
35+
template <>
36+
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 3>
37+
get<3>(const Accessor& t) {
38+
return {ExtractPoint(t[0]), ExtractPoint(t[1]), ExtractPoint(t[2])};
39+
}
40+
};
41+
42+
template <int H, typename Accessor>
43+
std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H>
44+
ExtractHull(const Accessor& t) {
45+
return ExtractHullHelper<Accessor>::template get<H>(t);
46+
}
47+
48+
template <typename T>
49+
void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) {
50+
t[0] += point.x;
51+
t[1] += point.y;
52+
t[2] += point.z;
53+
}
54+
55+
// distance between the convex hull of A points and B points
56+
// this could be done in c++17 with tuple_cat and invoke
57+
template <typename T>
58+
T HullDistance(
59+
const std::array<vec3<T>, 1>& a,
60+
const std::array<vec3<T>, 2>& b) {
61+
using std::get;
62+
return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b));
63+
}
64+
template <typename T>
65+
T HullDistance(
66+
const std::array<vec3<T>, 1>& a,
67+
const std::array<vec3<T>, 3>& b) {
68+
using std::get;
69+
return PointTriangle3DistanceForward(
70+
get<0>(a), get<0>(b), get<1>(b), get<2>(b));
71+
}
72+
template <typename T>
73+
T HullDistance(
74+
const std::array<vec3<T>, 2>& a,
75+
const std::array<vec3<T>, 1>& b) {
76+
return HullDistance(b, a);
77+
}
78+
template <typename T>
79+
T HullDistance(
80+
const std::array<vec3<T>, 3>& a,
81+
const std::array<vec3<T>, 1>& b) {
82+
return HullDistance(b, a);
83+
}
84+
85+
template <typename T>
86+
void HullHullDistanceBackward(
87+
const std::array<vec3<T>, 1>& a,
88+
const std::array<vec3<T>, 2>& b,
89+
T grad_dist,
90+
at::TensorAccessor<T, 1>&& grad_a,
91+
at::TensorAccessor<T, 2>&& grad_b) {
92+
using std::get;
93+
auto res =
94+
PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist);
95+
IncrementPoint(std::move(grad_a), get<0>(res));
96+
IncrementPoint(grad_b[0], get<1>(res));
97+
IncrementPoint(grad_b[1], get<2>(res));
98+
}
99+
template <typename T>
100+
void HullHullDistanceBackward(
101+
const std::array<vec3<T>, 1>& a,
102+
const std::array<vec3<T>, 3>& b,
103+
T grad_dist,
104+
at::TensorAccessor<T, 1>&& grad_a,
105+
at::TensorAccessor<T, 2>&& grad_b) {
106+
using std::get;
107+
auto res = PointTriangle3DistanceBackward(
108+
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist);
109+
IncrementPoint(std::move(grad_a), get<0>(res));
110+
IncrementPoint(grad_b[0], get<1>(res));
111+
IncrementPoint(grad_b[1], get<2>(res));
112+
IncrementPoint(grad_b[2], get<3>(res));
113+
}
114+
template <typename T>
115+
void HullHullDistanceBackward(
116+
const std::array<vec3<T>, 3>& a,
117+
const std::array<vec3<T>, 1>& b,
118+
T grad_dist,
119+
at::TensorAccessor<T, 2>&& grad_a,
120+
at::TensorAccessor<T, 1>&& grad_b) {
121+
return HullHullDistanceBackward(
122+
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
123+
}
124+
template <typename T>
125+
void HullHullDistanceBackward(
126+
const std::array<vec3<T>, 2>& a,
127+
const std::array<vec3<T>, 1>& b,
128+
T grad_dist,
129+
at::TensorAccessor<T, 2>&& grad_a,
130+
at::TensorAccessor<T, 1>&& grad_b) {
131+
return HullHullDistanceBackward(
132+
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
133+
}
134+
135+
template <int H>
136+
void ValidateShape(const at::Tensor& as) {
137+
if (H == 1) {
138+
TORCH_CHECK(as.size(1) == 3);
139+
} else {
140+
TORCH_CHECK(as.size(2) == 3 && as.size(1) == H);
141+
}
142+
}
143+
144+
// ----------- Here begins the implementation of each top-level
145+
// function using non-type template parameters to
146+
// implement all the cases in one go. ----------- //
147+
148+
template <int H1, int H2>
149+
std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
150+
const at::Tensor& as,
151+
const at::Tensor& as_first_idx,
152+
const at::Tensor& bs,
153+
const at::Tensor& bs_first_idx) {
154+
const int64_t A_N = as.size(0);
155+
const int64_t B_N = bs.size(0);
156+
const int64_t BATCHES = as_first_idx.size(0);
157+
158+
ValidateShape<H1>(as);
159+
ValidateShape<H2>(bs);
160+
161+
TORCH_CHECK(bs_first_idx.size(0) == BATCHES);
162+
163+
// clang-format off
164+
at::Tensor dists = at::zeros({A_N,}, as.options());
165+
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
166+
// clang-format on
167+
168+
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
169+
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
170+
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
171+
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
172+
auto dists_a = dists.accessor<float, 1>();
173+
auto idxs_a = idxs.accessor<int64_t, 1>();
174+
int64_t a_batch_end = 0;
175+
int64_t b_batch_start = 0, b_batch_end = 0;
176+
int64_t batch_idx = 0;
177+
for (int64_t a_n = 0; a_n < A_N; ++a_n) {
178+
if (a_n == a_batch_end) {
179+
++batch_idx;
180+
b_batch_start = b_batch_end;
181+
if (batch_idx == BATCHES) {
182+
a_batch_end = std::numeric_limits<int64_t>::max();
183+
b_batch_end = B_N;
184+
} else {
185+
a_batch_end = as_first_idx_a[batch_idx];
186+
b_batch_end = bs_first_idx_a[batch_idx];
187+
}
188+
}
189+
float min_dist = std::numeric_limits<float>::max();
190+
size_t min_idx = 0;
191+
auto a = ExtractHull<H1>(as_a[a_n]);
192+
for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) {
193+
float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n]));
194+
if (dist <= min_dist) {
195+
min_dist = dist;
196+
min_idx = b_n;
197+
}
198+
}
199+
dists_a[a_n] = min_dist;
200+
idxs_a[a_n] = min_idx;
201+
}
202+
203+
return std::make_tuple(dists, idxs);
204+
}
205+
206+
template <int H1, int H2>
207+
std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
208+
const at::Tensor& as,
209+
const at::Tensor& bs,
210+
const at::Tensor& idx_bs,
211+
const at::Tensor& grad_dists) {
212+
const int64_t A_N = as.size(0);
213+
214+
TORCH_CHECK(idx_bs.size(0) == A_N);
215+
TORCH_CHECK(grad_dists.size(0) == A_N);
216+
ValidateShape<H1>(as);
217+
ValidateShape<H2>(bs);
218+
219+
at::Tensor grad_as = at::zeros_like(as);
220+
at::Tensor grad_bs = at::zeros_like(bs);
221+
222+
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
223+
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
224+
auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
225+
auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
226+
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
227+
auto grad_dists_a = grad_dists.accessor<float, 1>();
228+
229+
for (int64_t a_n = 0; a_n < A_N; ++a_n) {
230+
auto a = ExtractHull<H1>(as_a[a_n]);
231+
auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]);
232+
HullHullDistanceBackward(
233+
a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]);
234+
}
235+
return std::make_tuple(grad_as, grad_bs);
236+
}
237+
238+
template <int H>
239+
torch::Tensor PointHullArrayDistanceForwardCpu(
240+
const torch::Tensor& points,
241+
const torch::Tensor& bs) {
242+
const int64_t P = points.size(0);
243+
const int64_t B_N = bs.size(0);
244+
245+
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
246+
ValidateShape<H>(bs);
247+
248+
at::Tensor dists = at::zeros({P, B_N}, points.options());
249+
auto points_a = points.accessor<float, 2>();
250+
auto bs_a = bs.accessor<float, 3>();
251+
auto dists_a = dists.accessor<float, 2>();
252+
for (int64_t p = 0; p < P; ++p) {
253+
auto point = ExtractHull<1>(points_a[p]);
254+
auto dest = dists_a[p];
255+
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
256+
auto b = ExtractHull<H>(bs_a[b_n]);
257+
dest[b_n] = HullDistance(point, b);
258+
}
259+
}
260+
return dists;
261+
}
262+
263+
template <int H>
264+
std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
265+
const at::Tensor& points,
266+
const at::Tensor& bs,
267+
const at::Tensor& grad_dists) {
268+
const int64_t P = points.size(0);
269+
const int64_t B_N = bs.size(0);
270+
271+
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
272+
ValidateShape<H>(bs);
273+
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == B_N));
274+
275+
at::Tensor grad_points = at::zeros({P, 3}, points.options());
276+
at::Tensor grad_bs = at::zeros({B_N, H, 3}, bs.options());
277+
278+
auto points_a = points.accessor<float, 2>();
279+
auto bs_a = bs.accessor<float, 3>();
280+
auto grad_dists_a = grad_dists.accessor<float, 2>();
281+
auto grad_points_a = grad_points.accessor<float, 2>();
282+
auto grad_bs_a = grad_bs.accessor<float, 3>();
283+
for (int64_t p = 0; p < P; ++p) {
284+
auto point = ExtractHull<1>(points_a[p]);
285+
auto grad_point = grad_points_a[p];
286+
auto grad_dist = grad_dists_a[p];
287+
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
288+
auto b = ExtractHull<H>(bs_a[b_n]);
289+
HullHullDistanceBackward(
290+
point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]);
291+
}
292+
}
293+
return std::make_tuple(grad_points, grad_bs);
294+
}
295+
296+
// ---------- Here begin the exported functions ------------ //
297+
298+
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
299+
const torch::Tensor& points,
300+
const torch::Tensor& points_first_idx,
301+
const torch::Tensor& tris,
302+
const torch::Tensor& tris_first_idx) {
303+
return HullHullDistanceForwardCpu<1, 3>(
304+
points, points_first_idx, tris, tris_first_idx);
305+
}
306+
307+
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
308+
const torch::Tensor& points,
309+
const torch::Tensor& tris,
310+
const torch::Tensor& idx_points,
311+
const torch::Tensor& grad_dists) {
312+
return HullHullDistanceBackwardCpu<1, 3>(
313+
points, tris, idx_points, grad_dists);
314+
}
315+
316+
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
317+
const torch::Tensor& points,
318+
const torch::Tensor& points_first_idx,
319+
const torch::Tensor& tris,
320+
const torch::Tensor& tris_first_idx) {
321+
return HullHullDistanceForwardCpu<3, 1>(
322+
tris, tris_first_idx, points, points_first_idx);
323+
}
324+
325+
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
326+
const torch::Tensor& points,
327+
const torch::Tensor& tris,
328+
const torch::Tensor& idx_tris,
329+
const torch::Tensor& grad_dists) {
330+
auto res =
331+
HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists);
332+
return std::make_tuple(std::get<1>(res), std::get<0>(res));
333+
}
334+
335+
// Dense point-to-edge distance table: PointHullArrayDistanceForwardCpu with
// two points per hull.
torch::Tensor PointEdgeArrayDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms) {
  constexpr int kEdgeHull = 2;
  return PointHullArrayDistanceForwardCpu<kEdgeHull>(points, segms);
}
340+
341+
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu(
342+
const at::Tensor& points,
343+
const at::Tensor& tris,
344+
const at::Tensor& grad_dists) {
345+
return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists);
346+
}
347+
348+
// Dense point-to-face distance table: PointHullArrayDistanceForwardCpu with
// three points per hull.
torch::Tensor PointFaceArrayDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& tris) {
  constexpr int kFaceHull = 3;
  return PointHullArrayDistanceForwardCpu<kFaceHull>(points, tris);
}
353+
354+
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu(
355+
const at::Tensor& points,
356+
const at::Tensor& segms,
357+
const at::Tensor& grad_dists) {
358+
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists);
359+
}
360+
361+
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
362+
const torch::Tensor& points,
363+
const torch::Tensor& points_first_idx,
364+
const torch::Tensor& segms,
365+
const torch::Tensor& segms_first_idx,
366+
const int64_t /*max_points*/) {
367+
return HullHullDistanceForwardCpu<1, 2>(
368+
points, points_first_idx, segms, segms_first_idx);
369+
}
370+
371+
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
372+
const torch::Tensor& points,
373+
const torch::Tensor& segms,
374+
const torch::Tensor& idx_points,
375+
const torch::Tensor& grad_dists) {
376+
return HullHullDistanceBackwardCpu<1, 2>(
377+
points, segms, idx_points, grad_dists);
378+
}
379+
380+
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
381+
const torch::Tensor& points,
382+
const torch::Tensor& points_first_idx,
383+
const torch::Tensor& segms,
384+
const torch::Tensor& segms_first_idx,
385+
const int64_t /*max_segms*/) {
386+
return HullHullDistanceForwardCpu<2, 1>(
387+
segms, segms_first_idx, points, points_first_idx);
388+
}
389+
390+
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
391+
const torch::Tensor& points,
392+
const torch::Tensor& segms,
393+
const torch::Tensor& idx_segms,
394+
const torch::Tensor& grad_dists) {
395+
auto res =
396+
HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists);
397+
return std::make_tuple(std::get<1>(res), std::get<0>(res));
398+
}

0 commit comments

Comments
 (0)