|
| 1 | +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | +#include <torch/extension.h> |
| 4 | +#include <array> |
| 5 | +#include <limits> |
| 6 | +#include "utils/geometry_utils.h" |
| 7 | +#include "utils/vec3.h" |
| 8 | + |
| 9 | +// - We start with implementations of simple operations on points, edges and |
| 10 | +// faces. The hull of H points is a point if H=1, an edge if H=2, a face if H=3. |
| 11 | + |
| 12 | +template <typename T> |
| 13 | +vec3<T> ExtractPoint(const at::TensorAccessor<T, 1>& t) { |
| 14 | + return vec3(t[0], t[1], t[2]); |
| 15 | +} |
| 16 | + |
| 17 | +template <class Accessor> |
| 18 | +struct ExtractHullHelper { |
| 19 | + template <int H> |
| 20 | + static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H> |
| 21 | + get(const Accessor& t); |
| 22 | + |
| 23 | + template <> |
| 24 | + static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 1> |
| 25 | + get<1>(const Accessor& t) { |
| 26 | + return {ExtractPoint(t)}; |
| 27 | + } |
| 28 | + |
| 29 | + template <> |
| 30 | + static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 2> |
| 31 | + get<2>(const Accessor& t) { |
| 32 | + return {ExtractPoint(t[0]), ExtractPoint(t[1])}; |
| 33 | + } |
| 34 | + |
| 35 | + template <> |
| 36 | + static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 3> |
| 37 | + get<3>(const Accessor& t) { |
| 38 | + return {ExtractPoint(t[0]), ExtractPoint(t[1]), ExtractPoint(t[2])}; |
| 39 | + } |
| 40 | +}; |
| 41 | + |
| 42 | +template <int H, typename Accessor> |
| 43 | +std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H> |
| 44 | +ExtractHull(const Accessor& t) { |
| 45 | + return ExtractHullHelper<Accessor>::template get<H>(t); |
| 46 | +} |
| 47 | + |
| 48 | +template <typename T> |
| 49 | +void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) { |
| 50 | + t[0] += point.x; |
| 51 | + t[1] += point.y; |
| 52 | + t[2] += point.z; |
| 53 | +} |
| 54 | + |
| 55 | +// distance between the convex hull of A points and B points |
| 56 | +// this could be done in c++17 with tuple_cat and invoke |
| 57 | +template <typename T> |
| 58 | +T HullDistance( |
| 59 | + const std::array<vec3<T>, 1>& a, |
| 60 | + const std::array<vec3<T>, 2>& b) { |
| 61 | + using std::get; |
| 62 | + return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b)); |
| 63 | +} |
| 64 | +template <typename T> |
| 65 | +T HullDistance( |
| 66 | + const std::array<vec3<T>, 1>& a, |
| 67 | + const std::array<vec3<T>, 3>& b) { |
| 68 | + using std::get; |
| 69 | + return PointTriangle3DistanceForward( |
| 70 | + get<0>(a), get<0>(b), get<1>(b), get<2>(b)); |
| 71 | +} |
| 72 | +template <typename T> |
| 73 | +T HullDistance( |
| 74 | + const std::array<vec3<T>, 2>& a, |
| 75 | + const std::array<vec3<T>, 1>& b) { |
| 76 | + return HullDistance(b, a); |
| 77 | +} |
| 78 | +template <typename T> |
| 79 | +T HullDistance( |
| 80 | + const std::array<vec3<T>, 3>& a, |
| 81 | + const std::array<vec3<T>, 1>& b) { |
| 82 | + return HullDistance(b, a); |
| 83 | +} |
| 84 | + |
| 85 | +template <typename T> |
| 86 | +void HullHullDistanceBackward( |
| 87 | + const std::array<vec3<T>, 1>& a, |
| 88 | + const std::array<vec3<T>, 2>& b, |
| 89 | + T grad_dist, |
| 90 | + at::TensorAccessor<T, 1>&& grad_a, |
| 91 | + at::TensorAccessor<T, 2>&& grad_b) { |
| 92 | + using std::get; |
| 93 | + auto res = |
| 94 | + PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist); |
| 95 | + IncrementPoint(std::move(grad_a), get<0>(res)); |
| 96 | + IncrementPoint(grad_b[0], get<1>(res)); |
| 97 | + IncrementPoint(grad_b[1], get<2>(res)); |
| 98 | +} |
| 99 | +template <typename T> |
| 100 | +void HullHullDistanceBackward( |
| 101 | + const std::array<vec3<T>, 1>& a, |
| 102 | + const std::array<vec3<T>, 3>& b, |
| 103 | + T grad_dist, |
| 104 | + at::TensorAccessor<T, 1>&& grad_a, |
| 105 | + at::TensorAccessor<T, 2>&& grad_b) { |
| 106 | + using std::get; |
| 107 | + auto res = PointTriangle3DistanceBackward( |
| 108 | + get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist); |
| 109 | + IncrementPoint(std::move(grad_a), get<0>(res)); |
| 110 | + IncrementPoint(grad_b[0], get<1>(res)); |
| 111 | + IncrementPoint(grad_b[1], get<2>(res)); |
| 112 | + IncrementPoint(grad_b[2], get<3>(res)); |
| 113 | +} |
| 114 | +template <typename T> |
| 115 | +void HullHullDistanceBackward( |
| 116 | + const std::array<vec3<T>, 3>& a, |
| 117 | + const std::array<vec3<T>, 1>& b, |
| 118 | + T grad_dist, |
| 119 | + at::TensorAccessor<T, 2>&& grad_a, |
| 120 | + at::TensorAccessor<T, 1>&& grad_b) { |
| 121 | + return HullHullDistanceBackward( |
| 122 | + b, a, grad_dist, std::move(grad_b), std::move(grad_a)); |
| 123 | +} |
| 124 | +template <typename T> |
| 125 | +void HullHullDistanceBackward( |
| 126 | + const std::array<vec3<T>, 2>& a, |
| 127 | + const std::array<vec3<T>, 1>& b, |
| 128 | + T grad_dist, |
| 129 | + at::TensorAccessor<T, 2>&& grad_a, |
| 130 | + at::TensorAccessor<T, 1>&& grad_b) { |
| 131 | + return HullHullDistanceBackward( |
| 132 | + b, a, grad_dist, std::move(grad_b), std::move(grad_a)); |
| 133 | +} |
| 134 | + |
| 135 | +template <int H> |
| 136 | +void ValidateShape(const at::Tensor& as) { |
| 137 | + if (H == 1) { |
| 138 | + TORCH_CHECK(as.size(1) == 3); |
| 139 | + } else { |
| 140 | + TORCH_CHECK(as.size(2) == 3 && as.size(1) == H); |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +// ----------- Here begins the implementation of each top-level |
| 145 | +// function using non-type template parameters to |
| 146 | +// implement all the cases in one go. ----------- // |
| 147 | + |
| 148 | +template <int H1, int H2> |
| 149 | +std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu( |
| 150 | + const at::Tensor& as, |
| 151 | + const at::Tensor& as_first_idx, |
| 152 | + const at::Tensor& bs, |
| 153 | + const at::Tensor& bs_first_idx) { |
| 154 | + const int64_t A_N = as.size(0); |
| 155 | + const int64_t B_N = bs.size(0); |
| 156 | + const int64_t BATCHES = as_first_idx.size(0); |
| 157 | + |
| 158 | + ValidateShape<H1>(as); |
| 159 | + ValidateShape<H2>(bs); |
| 160 | + |
| 161 | + TORCH_CHECK(bs_first_idx.size(0) == BATCHES); |
| 162 | + |
| 163 | + // clang-format off |
| 164 | + at::Tensor dists = at::zeros({A_N,}, as.options()); |
| 165 | + at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options()); |
| 166 | + // clang-format on |
| 167 | + |
| 168 | + auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > (); |
| 169 | + auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > (); |
| 170 | + auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>(); |
| 171 | + auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>(); |
| 172 | + auto dists_a = dists.accessor<float, 1>(); |
| 173 | + auto idxs_a = idxs.accessor<int64_t, 1>(); |
| 174 | + int64_t a_batch_end = 0; |
| 175 | + int64_t b_batch_start = 0, b_batch_end = 0; |
| 176 | + int64_t batch_idx = 0; |
| 177 | + for (int64_t a_n = 0; a_n < A_N; ++a_n) { |
| 178 | + if (a_n == a_batch_end) { |
| 179 | + ++batch_idx; |
| 180 | + b_batch_start = b_batch_end; |
| 181 | + if (batch_idx == BATCHES) { |
| 182 | + a_batch_end = std::numeric_limits<int64_t>::max(); |
| 183 | + b_batch_end = B_N; |
| 184 | + } else { |
| 185 | + a_batch_end = as_first_idx_a[batch_idx]; |
| 186 | + b_batch_end = bs_first_idx_a[batch_idx]; |
| 187 | + } |
| 188 | + } |
| 189 | + float min_dist = std::numeric_limits<float>::max(); |
| 190 | + size_t min_idx = 0; |
| 191 | + auto a = ExtractHull<H1>(as_a[a_n]); |
| 192 | + for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) { |
| 193 | + float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n])); |
| 194 | + if (dist <= min_dist) { |
| 195 | + min_dist = dist; |
| 196 | + min_idx = b_n; |
| 197 | + } |
| 198 | + } |
| 199 | + dists_a[a_n] = min_dist; |
| 200 | + idxs_a[a_n] = min_idx; |
| 201 | + } |
| 202 | + |
| 203 | + return std::make_tuple(dists, idxs); |
| 204 | +} |
| 205 | + |
| 206 | +template <int H1, int H2> |
| 207 | +std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu( |
| 208 | + const at::Tensor& as, |
| 209 | + const at::Tensor& bs, |
| 210 | + const at::Tensor& idx_bs, |
| 211 | + const at::Tensor& grad_dists) { |
| 212 | + const int64_t A_N = as.size(0); |
| 213 | + |
| 214 | + TORCH_CHECK(idx_bs.size(0) == A_N); |
| 215 | + TORCH_CHECK(grad_dists.size(0) == A_N); |
| 216 | + ValidateShape<H1>(as); |
| 217 | + ValidateShape<H2>(bs); |
| 218 | + |
| 219 | + at::Tensor grad_as = at::zeros_like(as); |
| 220 | + at::Tensor grad_bs = at::zeros_like(bs); |
| 221 | + |
| 222 | + auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > (); |
| 223 | + auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > (); |
| 224 | + auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > (); |
| 225 | + auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > (); |
| 226 | + auto idx_bs_a = idx_bs.accessor<int64_t, 1>(); |
| 227 | + auto grad_dists_a = grad_dists.accessor<float, 1>(); |
| 228 | + |
| 229 | + for (int64_t a_n = 0; a_n < A_N; ++a_n) { |
| 230 | + auto a = ExtractHull<H1>(as_a[a_n]); |
| 231 | + auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]); |
| 232 | + HullHullDistanceBackward( |
| 233 | + a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]); |
| 234 | + } |
| 235 | + return std::make_tuple(grad_as, grad_bs); |
| 236 | +} |
| 237 | + |
| 238 | +template <int H> |
| 239 | +torch::Tensor PointHullArrayDistanceForwardCpu( |
| 240 | + const torch::Tensor& points, |
| 241 | + const torch::Tensor& bs) { |
| 242 | + const int64_t P = points.size(0); |
| 243 | + const int64_t B_N = bs.size(0); |
| 244 | + |
| 245 | + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); |
| 246 | + ValidateShape<H>(bs); |
| 247 | + |
| 248 | + at::Tensor dists = at::zeros({P, B_N}, points.options()); |
| 249 | + auto points_a = points.accessor<float, 2>(); |
| 250 | + auto bs_a = bs.accessor<float, 3>(); |
| 251 | + auto dists_a = dists.accessor<float, 2>(); |
| 252 | + for (int64_t p = 0; p < P; ++p) { |
| 253 | + auto point = ExtractHull<1>(points_a[p]); |
| 254 | + auto dest = dists_a[p]; |
| 255 | + for (int64_t b_n = 0; b_n < B_N; ++b_n) { |
| 256 | + auto b = ExtractHull<H>(bs_a[b_n]); |
| 257 | + dest[b_n] = HullDistance(point, b); |
| 258 | + } |
| 259 | + } |
| 260 | + return dists; |
| 261 | +} |
| 262 | + |
| 263 | +template <int H> |
| 264 | +std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu( |
| 265 | + const at::Tensor& points, |
| 266 | + const at::Tensor& bs, |
| 267 | + const at::Tensor& grad_dists) { |
| 268 | + const int64_t P = points.size(0); |
| 269 | + const int64_t B_N = bs.size(0); |
| 270 | + |
| 271 | + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); |
| 272 | + ValidateShape<H>(bs); |
| 273 | + TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == B_N)); |
| 274 | + |
| 275 | + at::Tensor grad_points = at::zeros({P, 3}, points.options()); |
| 276 | + at::Tensor grad_bs = at::zeros({B_N, H, 3}, bs.options()); |
| 277 | + |
| 278 | + auto points_a = points.accessor<float, 2>(); |
| 279 | + auto bs_a = bs.accessor<float, 3>(); |
| 280 | + auto grad_dists_a = grad_dists.accessor<float, 2>(); |
| 281 | + auto grad_points_a = grad_points.accessor<float, 2>(); |
| 282 | + auto grad_bs_a = grad_bs.accessor<float, 3>(); |
| 283 | + for (int64_t p = 0; p < P; ++p) { |
| 284 | + auto point = ExtractHull<1>(points_a[p]); |
| 285 | + auto grad_point = grad_points_a[p]; |
| 286 | + auto grad_dist = grad_dists_a[p]; |
| 287 | + for (int64_t b_n = 0; b_n < B_N; ++b_n) { |
| 288 | + auto b = ExtractHull<H>(bs_a[b_n]); |
| 289 | + HullHullDistanceBackward( |
| 290 | + point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]); |
| 291 | + } |
| 292 | + } |
| 293 | + return std::make_tuple(grad_points, grad_bs); |
| 294 | +} |
| 295 | + |
| 296 | +// ---------- Here begin the exported functions ------------ // |
| 297 | + |
| 298 | +std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu( |
| 299 | + const torch::Tensor& points, |
| 300 | + const torch::Tensor& points_first_idx, |
| 301 | + const torch::Tensor& tris, |
| 302 | + const torch::Tensor& tris_first_idx) { |
| 303 | + return HullHullDistanceForwardCpu<1, 3>( |
| 304 | + points, points_first_idx, tris, tris_first_idx); |
| 305 | +} |
| 306 | + |
| 307 | +std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu( |
| 308 | + const torch::Tensor& points, |
| 309 | + const torch::Tensor& tris, |
| 310 | + const torch::Tensor& idx_points, |
| 311 | + const torch::Tensor& grad_dists) { |
| 312 | + return HullHullDistanceBackwardCpu<1, 3>( |
| 313 | + points, tris, idx_points, grad_dists); |
| 314 | +} |
| 315 | + |
| 316 | +std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu( |
| 317 | + const torch::Tensor& points, |
| 318 | + const torch::Tensor& points_first_idx, |
| 319 | + const torch::Tensor& tris, |
| 320 | + const torch::Tensor& tris_first_idx) { |
| 321 | + return HullHullDistanceForwardCpu<3, 1>( |
| 322 | + tris, tris_first_idx, points, points_first_idx); |
| 323 | +} |
| 324 | + |
| 325 | +std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu( |
| 326 | + const torch::Tensor& points, |
| 327 | + const torch::Tensor& tris, |
| 328 | + const torch::Tensor& idx_tris, |
| 329 | + const torch::Tensor& grad_dists) { |
| 330 | + auto res = |
| 331 | + HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists); |
| 332 | + return std::make_tuple(std::get<1>(res), std::get<0>(res)); |
| 333 | +} |
| 334 | + |
| 335 | +torch::Tensor PointEdgeArrayDistanceForwardCpu( |
| 336 | + const torch::Tensor& points, |
| 337 | + const torch::Tensor& segms) { |
| 338 | + return PointHullArrayDistanceForwardCpu<2>(points, segms); |
| 339 | +} |
| 340 | + |
| 341 | +std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu( |
| 342 | + const at::Tensor& points, |
| 343 | + const at::Tensor& tris, |
| 344 | + const at::Tensor& grad_dists) { |
| 345 | + return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists); |
| 346 | +} |
| 347 | + |
| 348 | +torch::Tensor PointFaceArrayDistanceForwardCpu( |
| 349 | + const torch::Tensor& points, |
| 350 | + const torch::Tensor& tris) { |
| 351 | + return PointHullArrayDistanceForwardCpu<3>(points, tris); |
| 352 | +} |
| 353 | + |
| 354 | +std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu( |
| 355 | + const at::Tensor& points, |
| 356 | + const at::Tensor& segms, |
| 357 | + const at::Tensor& grad_dists) { |
| 358 | + return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists); |
| 359 | +} |
| 360 | + |
| 361 | +std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu( |
| 362 | + const torch::Tensor& points, |
| 363 | + const torch::Tensor& points_first_idx, |
| 364 | + const torch::Tensor& segms, |
| 365 | + const torch::Tensor& segms_first_idx, |
| 366 | + const int64_t /*max_points*/) { |
| 367 | + return HullHullDistanceForwardCpu<1, 2>( |
| 368 | + points, points_first_idx, segms, segms_first_idx); |
| 369 | +} |
| 370 | + |
| 371 | +std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu( |
| 372 | + const torch::Tensor& points, |
| 373 | + const torch::Tensor& segms, |
| 374 | + const torch::Tensor& idx_points, |
| 375 | + const torch::Tensor& grad_dists) { |
| 376 | + return HullHullDistanceBackwardCpu<1, 2>( |
| 377 | + points, segms, idx_points, grad_dists); |
| 378 | +} |
| 379 | + |
| 380 | +std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu( |
| 381 | + const torch::Tensor& points, |
| 382 | + const torch::Tensor& points_first_idx, |
| 383 | + const torch::Tensor& segms, |
| 384 | + const torch::Tensor& segms_first_idx, |
| 385 | + const int64_t /*max_segms*/) { |
| 386 | + return HullHullDistanceForwardCpu<2, 1>( |
| 387 | + segms, segms_first_idx, points, points_first_idx); |
| 388 | +} |
| 389 | + |
| 390 | +std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu( |
| 391 | + const torch::Tensor& points, |
| 392 | + const torch::Tensor& segms, |
| 393 | + const torch::Tensor& idx_segms, |
| 394 | + const torch::Tensor& grad_dists) { |
| 395 | + auto res = |
| 396 | + HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists); |
| 397 | + return std::make_tuple(std::get<1>(res), std::get<0>(res)); |
| 398 | +} |
0 commit comments