Skip to content

Commit 75dd149

Browse files
committed
Extend FPS with an extra ptr argument
1 parent 32bee64 commit 75dd149

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

test/test_fps.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def test_fps(dtype, device):
2525
[+2, -2],
2626
], dtype, device)
2727
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
28+
ptr = [0, 4, 8]
29+
ptr_tensor = tensor(ptr, torch.long, device)
2830

2931
out = fps(x, batch, random_start=False)
3032
assert out.tolist() == [0, 2, 4, 6]
@@ -36,6 +38,12 @@ def test_fps(dtype, device):
3638
random_start=False)
3739
assert out.tolist() == [0, 2, 4, 6]
3840

41+
out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
42+
assert out.tolist() == [0, 2, 4, 6]
43+
44+
out = fps(x, ptr=ptr_tensor, ratio=0.5, random_start=False)
45+
assert out.tolist() == [0, 2, 4, 6]
46+
3947
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
4048
random_start=False)
4149
assert out.tolist() == [0, 2, 4, 6]
@@ -66,3 +74,6 @@ def test_random_fps(device):
6674
batch = torch.cat([batch_1, batch_2])
6775
idx = fps(pos, batch, ratio=0.5)
6876
assert idx.min() >= 0 and idx.max() < 2 * N
77+
ptr = [0, N, 2*N]
78+
idx = fps(pos, ptr=ptr, ratio=0.5)
79+
assert idx.min() >= 0 and idx.max() < 2 * N

torch_cluster/fps.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,46 @@
1-
from typing import Optional, Union
1+
from typing import Optional, Union, List
22

33
import torch
44
from torch import Tensor
55

66

77
@torch.jit._overload # noqa
8-
def fps(src, batch, ratio, random_start, batch_size): # noqa
9-
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
8+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
9+
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
1010
pass # pragma: no cover
1111

1212

1313
@torch.jit._overload # noqa
14-
def fps(src, batch, ratio, random_start, batch_size): # noqa
15-
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
14+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
15+
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
1616
pass # pragma: no cover
1717

1818

19-
def fps( # noqa
19+
@torch.jit._overload # noqa
20+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
21+
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
22+
pass # pragma: no cover
23+
24+
25+
@torch.jit._overload # noqa
26+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
27+
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
28+
pass # pragma: no cover
29+
30+
31+
@torch.jit._overload # noqa
32+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
33+
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
34+
pass # pragma: no cover
35+
36+
37+
def fps_ptr_tensor(
2038
src: torch.Tensor,
2139
batch: Optional[Tensor] = None,
22-
ratio: Optional[Union[torch.Tensor, float]] = None,
40+
ratio: Optional[Union[Tensor, float]] = None,
2341
random_start: bool = True,
2442
batch_size: Optional[int] = None,
43+
ptr: Optional[Union[Tensor, List[int]]] = None,
2544
):
2645
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
2746
Learning on Point Sets in a Metric Space"
@@ -40,6 +59,9 @@ def fps( # noqa
4059
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
4160
batch_size (int, optional): The number of examples :math:`B`.
4261
Automatically calculated if not given. (default: :obj:`None`)
62+
ptr (LongTensor or list of ints): Ptr vector, which defines nodes
63+
ranges for consecutive batches, e.g. batch=[0,0,1,1,1,2] translates
64+
to ptr=[0,2,5,6].
4365
4466
:rtype: :class:`LongTensor`
4567
@@ -52,7 +74,6 @@ def fps( # noqa
5274
batch = torch.tensor([0, 0, 0, 0])
5375
index = fps(src, batch, ratio=0.5)
5476
"""
55-
5677
r: Optional[Tensor] = None
5778
if ratio is None:
5879
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
@@ -61,6 +82,7 @@ def fps( # noqa
6182
else:
6283
r = ratio
6384
assert r is not None
85+
assert batch is None or ptr is None
6486

6587
if batch is not None:
6688
assert src.size(0) == batch.numel()
@@ -70,9 +92,33 @@ def fps( # noqa
7092
deg = src.new_zeros(batch_size, dtype=torch.long)
7193
deg.scatter_add_(0, batch, torch.ones_like(batch))
7294

73-
ptr = deg.new_zeros(batch_size + 1)
74-
torch.cumsum(deg, 0, out=ptr[1:])
95+
p = deg.new_zeros(batch_size + 1)
96+
torch.cumsum(deg, 0, out=p[1:])
7597
else:
76-
ptr = torch.tensor([0, src.size(0)], device=src.device)
98+
if ptr is None:
99+
p = torch.tensor([0, src.size(0)], device=src.device)
100+
else:
101+
if isinstance(ptr, Tensor):
102+
p = ptr
103+
else:
104+
p = torch.tensor(ptr, device=src.device)
105+
106+
return torch.ops.torch_cluster.fps(src, p, r, random_start)
107+
108+
109+
def fps_ptr_list(
110+
src: torch.Tensor,
111+
batch: Optional[Tensor] = None,
112+
ratio: Optional[float] = None,
113+
random_start: bool = True,
114+
batch_size: Optional[int] = None,
115+
ptr: Optional[List[int]] = None,
116+
):
117+
if ptr is not None:
118+
return torch.ops.torch_cluster.fps_ptr_list(src, ptr,
119+
ratio, random_start)
120+
return fps_ptr_tensor(src, batch, ratio, random_start, batch_size, ptr)
121+
77122

78-
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
123+
fps = fps_ptr_list if hasattr( # noqa
124+
torch.ops.torch_cluster, 'fp_ptr_list') else fps_ptr_tensor

0 commit comments

Comments
 (0)