From d149d431a6f24771e8ae886454ac23ba7833941c Mon Sep 17 00:00:00 2001 From: Piotr Chmiel Date: Wed, 26 Apr 2023 13:40:40 +0100 Subject: [PATCH] Add batch_size argument for fps, knn, radius functions. It can be used to avoid additional calculations if a user is using fixed-size batch. --- torch_cluster/fps.py | 20 ++++++++++---------- torch_cluster/knn.py | 31 +++++++++++++++++++++---------- torch_cluster/radius.py | 32 ++++++++++++++++++++++---------- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/torch_cluster/fps.py b/torch_cluster/fps.py index 7901dd5..5f7dca9 100644 --- a/torch_cluster/fps.py +++ b/torch_cluster/fps.py @@ -5,18 +5,14 @@ @torch.jit._overload # noqa -def fps(src, batch=None, ratio=None, random_start=True): # noqa - # type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor +def fps(src, batch=None, ratio=None, random_start=True, batch_size=None): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], bool + # Optional[int]) -> Tensor pass # pragma: no cover -@torch.jit._overload # noqa -def fps(src, batch=None, ratio=None, random_start=True): # noqa - # type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor - pass # pragma: no cover - - -def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa +def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True, # noqa + batch_size=None): # noqa r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ paper, which iteratively samples the @@ -32,6 +28,9 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa (default: :obj:`0.5`) random_start (bool, optional): If set to :obj:`False`, use the first node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -57,7 +56,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa if batch is not None: assert src.size(0) == batch.numel() - batch_size = int(batch.max()) + 1 + if batch_size is None: + batch_size = int(batch.max()) + 1 deg = src.new_zeros(batch_size, dtype=torch.long) deg.scatter_add_(0, batch, torch.ones_like(batch)) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index b981c46..938f800 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -7,7 +7,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, batch_x: Optional[torch.Tensor] = None, batch_y: Optional[torch.Tensor] = None, cosine: bool = False, - num_workers: int = 1) -> torch.Tensor: + num_workers: int = 1, + batch_size: Optional[int] = None) -> torch.Tensor: r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in :obj:`x`. @@ -31,6 +32,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -52,13 +56,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, y = y.view(-1, 1) if y.dim() == 1 else y x, y = x.contiguous(), y.contiguous() - batch_size = 1 - if batch_x is not None: - assert x.size(0) == batch_x.numel() - batch_size = int(batch_x.max()) + 1 - if batch_y is not None: - assert y.size(0) == batch_y.numel() - batch_size = max(batch_size, int(batch_y.max()) + 1) + if batch_size is None: + batch_size = 1 + if batch_x is not None: + assert x.size(0) == batch_x.numel() + batch_size = int(batch_x.max()) + 1 + if batch_y is not None: + assert y.size(0) == batch_y.numel() + batch_size = max(batch_size, int(batch_y.max()) + 1) + + assert batch_size > 0 ptr_x: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None @@ -76,7 +83,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, @torch.jit.script def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, loop: bool = False, flow: str = 'source_to_target', - cosine: bool = False, num_workers: int = 1) -> torch.Tensor: + cosine: bool = False, num_workers: int = 1, + batch_size: Optional[int] = None) -> torch.Tensor: r"""Computes graph edges to the nearest :obj:`k` points. Args: @@ -98,6 +106,9 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -113,7 +124,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, assert flow in ['source_to_target', 'target_to_source'] edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, - num_workers) + num_workers, batch_size) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index fd73b75..12aeee2 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -7,7 +7,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, batch_x: Optional[torch.Tensor] = None, batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32, - num_workers: int = 1) -> torch.Tensor: + num_workers: int = 1, + batch_size: Optional[int] = None) -> torch.Tensor: r"""Finds for each element in :obj:`y` all points in :obj:`x` within distance :obj:`r`. @@ -33,6 +34,9 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + (default: :obj:`None`) .. code-block:: python @@ -52,16 +56,20 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, y = y.view(-1, 1) if y.dim() == 1 else y x, y = x.contiguous(), y.contiguous() - batch_size = 1 - if batch_x is not None: - assert x.size(0) == batch_x.numel() - batch_size = int(batch_x.max()) + 1 - if batch_y is not None: - assert y.size(0) == batch_y.numel() - batch_size = max(batch_size, int(batch_y.max()) + 1) + if batch_size is None: + batch_size = 1 + if batch_x is not None: + assert x.size(0) == batch_x.numel() + batch_size = int(batch_x.max()) + 1 + if batch_y is not None: + assert y.size(0) == batch_y.numel() + batch_size = max(batch_size, int(batch_y.max()) + 1) + + assert batch_size > 0 ptr_x: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None + if batch_size > 1: assert batch_x is not None assert batch_y is not None @@ -77,7 +85,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, def radius_graph(x: torch.Tensor, r: float, batch: Optional[torch.Tensor] = None, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', - num_workers: int = 1) -> torch.Tensor: + num_workers: int = 1, + batch_size: Optional[int] = None) -> torch.Tensor: r"""Computes graph edges to all points within a given distance. Args: @@ -101,6 +110,9 @@ def radius_graph(x: torch.Tensor, r: float, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -117,7 +129,7 @@ def radius_graph(x: torch.Tensor, r: float, assert flow in ['source_to_target', 'target_to_source'] edge_index = radius(x, x, r, batch, batch, max_num_neighbors if loop else max_num_neighbors + 1, - num_workers) + num_workers, batch_size) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] else: