1
- from typing import Optional , Union
1
+ from typing import Optional , Union , List
2
2
3
3
import torch
4
4
from torch import Tensor
5
5
6
6
7
7
@torch .jit ._overload # noqa
8
- def fps (src , batch , ratio , random_start , batch_size ): # noqa
9
- # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
8
+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
9
+ # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor] ) -> Tensor # noqa
10
10
pass # pragma: no cover
11
11
12
12
13
13
@torch .jit ._overload # noqa
14
- def fps (src , batch , ratio , random_start , batch_size ): # noqa
15
- # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
14
+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
15
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor] ) -> Tensor # noqa
16
16
pass # pragma: no cover
17
17
18
18
19
- def fps ( # noqa
19
+ @torch .jit ._overload # noqa
20
+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
21
+ # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
22
+ pass # pragma: no cover
23
+
24
+
25
+ @torch .jit ._overload # noqa
26
+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
27
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
28
+ pass # pragma: no cover
29
+
30
+
31
+ @torch .jit ._overload # noqa
32
+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
33
+ # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
34
+ pass # pragma: no cover
35
+
36
+
37
+ def fps_ptr_tensor (
20
38
src : torch .Tensor ,
21
39
batch : Optional [Tensor ] = None ,
22
- ratio : Optional [Union [torch . Tensor , float ]] = None ,
40
+ ratio : Optional [Union [Tensor , float ]] = None ,
23
41
random_start : bool = True ,
24
42
batch_size : Optional [int ] = None ,
43
+ ptr : Optional [Union [Tensor , List [int ]]] = None ,
25
44
):
26
45
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
27
46
Learning on Point Sets in a Metric Space"
@@ -40,6 +59,9 @@ def fps( # noqa
40
59
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
41
60
batch_size (int, optional): The number of examples :math:`B`.
42
61
Automatically calculated if not given. (default: :obj:`None`)
62
+ ptr (LongTensor or list of ints): Ptr vector, which defines nodes
63
+ ranges for consecutive batches, e.g. batch=[0,0,1,1,1,2] translates
64
+ to ptr=[0,2,5,6].
43
65
44
66
:rtype: :class:`LongTensor`
45
67
@@ -52,7 +74,6 @@ def fps( # noqa
52
74
batch = torch.tensor([0, 0, 0, 0])
53
75
index = fps(src, batch, ratio=0.5)
54
76
"""
55
-
56
77
r : Optional [Tensor ] = None
57
78
if ratio is None :
58
79
r = torch .tensor (0.5 , dtype = src .dtype , device = src .device )
@@ -61,6 +82,7 @@ def fps( # noqa
61
82
else :
62
83
r = ratio
63
84
assert r is not None
85
+ assert batch is None or ptr is None
64
86
65
87
if batch is not None :
66
88
assert src .size (0 ) == batch .numel ()
@@ -70,9 +92,33 @@ def fps( # noqa
70
92
deg = src .new_zeros (batch_size , dtype = torch .long )
71
93
deg .scatter_add_ (0 , batch , torch .ones_like (batch ))
72
94
73
- ptr = deg .new_zeros (batch_size + 1 )
74
- torch .cumsum (deg , 0 , out = ptr [1 :])
95
+ p = deg .new_zeros (batch_size + 1 )
96
+ torch .cumsum (deg , 0 , out = p [1 :])
75
97
else :
76
- ptr = torch .tensor ([0 , src .size (0 )], device = src .device )
98
+ if ptr is None :
99
+ p = torch .tensor ([0 , src .size (0 )], device = src .device )
100
+ else :
101
+ if isinstance (ptr , Tensor ):
102
+ p = ptr
103
+ else :
104
+ p = torch .tensor (ptr , device = src .device )
105
+
106
+ return torch .ops .torch_cluster .fps (src , p , r , random_start )
107
+
108
+
109
+ def fps_ptr_list (
110
+ src : torch .Tensor ,
111
+ batch : Optional [Tensor ] = None ,
112
+ ratio : Optional [float ] = None ,
113
+ random_start : bool = True ,
114
+ batch_size : Optional [int ] = None ,
115
+ ptr : Optional [List [int ]] = None ,
116
+ ):
117
+ if ptr is not None :
118
+ return torch .ops .torch_cluster .fps_ptr_list (src , ptr ,
119
+ ratio , random_start )
120
+ return fps_ptr_tensor (src , batch , ratio , random_start , batch_size , ptr )
121
+
77
122
78
- return torch .ops .torch_cluster .fps (src , ptr , r , random_start )
123
+ fps = fps_ptr_list if hasattr ( # noqa
124
+ torch .ops .torch_cluster , 'fp_ptr_list' ) else fps_ptr_tensor
0 commit comments