Skip to content

Commit a6508ac

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Fix: Pointclouds.inside_box reducing over spatial dimensions.
Summary: As subj. Tests corrected accordingly. Also changed the test to provide a bit better diagnostics. Reviewed By: bottler Differential Revision: D32879498 fbshipit-source-id: 0a852e4a13dcb4ca3e54d71c6b263c5d2eeaf4eb
1 parent d9f7095 commit a6508ac

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

pytorch3d/structures/pointclouds.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,5 +1176,5 @@ def inside_box(self, box):
11761176
]
11771177
box = torch.cat(box, 0)
11781178

1179-
idx = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
1180-
return idx
1179+
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
1180+
return coord_inside.all(dim=-1)

tests/common_testing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,17 @@ def assertClose(
163163
if close:
164164
return
165165

166-
diff = backend.abs(input + 0.0 - other)
167-
ratio = diff / backend.abs(other)
166+
# handle bool case
167+
if backend == torch and input.dtype == torch.bool:
168+
diff = (input != other).float()
169+
ratio = diff
170+
if backend == np and input.dtype == bool:
171+
diff = (input != other).astype(float)
172+
ratio = diff
173+
else:
174+
diff = backend.abs(input + 0.0 - other)
175+
ratio = diff / backend.abs(other)
176+
168177
try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
169178
if try_relative.all():
170179
if backend == np:

tests/test_pointclouds.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,9 @@ def test_update_padded(self):
976976

977977
def test_inside_box(self):
978978
def inside_box_naive(cloud, box_min, box_max):
979-
return (cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))
979+
return ((cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))).all(
980+
dim=-1
981+
)
980982

981983
N, P, C = 5, 100, 4
982984

@@ -994,7 +996,7 @@ def inside_box_naive(cloud, box_min, box_max):
994996
for i, cloud in enumerate(clouds.points_list()):
995997
within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1]))
996998
within_box_naive = torch.cat(within_box_naive, 0)
997-
self.assertTrue(within_box.eq(within_box_naive).all())
999+
self.assertClose(within_box, within_box_naive)
9981000

9991001
# box of shape 2x3
10001002
box2 = box[0, :]
@@ -1005,13 +1007,13 @@ def inside_box_naive(cloud, box_min, box_max):
10051007
for cloud in clouds.points_list():
10061008
within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1]))
10071009
within_box_naive2 = torch.cat(within_box_naive2, 0)
1008-
self.assertTrue(within_box2.eq(within_box_naive2).all())
1010+
self.assertClose(within_box2, within_box_naive2)
10091011

10101012
# box of shape 1x2x3
10111013
box3 = box2.expand(1, 2, 3)
10121014

10131015
within_box3 = clouds.inside_box(box3)
1014-
self.assertTrue(within_box2.eq(within_box3).all())
1016+
self.assertClose(within_box2, within_box3)
10151017

10161018
# invalid box
10171019
invalid_box = torch.cat(

0 commit comments

Comments
 (0)