Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions benchmarks/benchmarks/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
from MDAnalysis.lib.pkdtree import PeriodicKDTree
from MDAnalysis.lib.distances import capped_distance
from scipy.spatial import cKDTree


class NeighborsBench:
    """Benchmarks for neighbor searching functions.

    Every benchmark runs once per combination of ``number_of_atoms`` and
    ``cutoff`` listed in :attr:`params`. The same single query point (the
    centre of the box) is searched against the same set of positions by
    :class:`PeriodicKDTree`, :class:`scipy.spatial.cKDTree` and
    :func:`capped_distance`.
    """

    params = ([100, 1000, 10000, 100000], [20, 30, 36, 42, 48, 50, 60])
    param_names = ["number_of_atoms", "cutoff"]

    # Fixed seed so that a given parameter combination always benchmarks
    # the same point cloud; otherwise timings from different runs are
    # measured on different inputs.
    _rng_seed = 42

    def setup(self, number_of_atoms, cutoff):
        """Setup called before each benchmark with each parameter combination.

        Builds the box, the positions, the query centre and one pre-built
        tree per implementation.

        Parameters
        ----------
        number_of_atoms : int
            Number of positions to generate.
        cutoff : float
            Search radius; also handed to ``PeriodicKDTree.set_coords``.
        """
        self.box = np.array(
            [170.0, 70.0, 120.0, 90.0, 90.0, 90.0], dtype=np.float32
        )
        rng = np.random.default_rng(self._rng_seed)
        # Uniform samples in [0, 1) scaled by the three box lengths.
        self.positions = (
            rng.random((number_of_atoms, 3)) * self.box[:3]
        ).astype(np.float32)
        # A single query point at the middle of the box, shape (1, 3).
        self.centre = (self.box[:3] / 2.0).reshape(1, 3)
        self.cutoff = cutoff

        # Trees built here are excluded from the "search only" timings.
        self.scipy_tree = cKDTree(self.positions, boxsize=self.box[:3])
        self.mda_tree = PeriodicKDTree(box=self.box)
        self.mda_tree.set_coords(self.positions, cutoff=self.cutoff)

    def time_mda_tree_search(self, number_of_atoms, cutoff):
        """Benchmark just the search operation on pre-built tree."""
        self.mda_tree.search(self.centre, self.cutoff)

    def time_scipy_tree_query(self, number_of_atoms, cutoff):
        """Benchmark just the query operation on pre-built tree."""
        self.scipy_tree.query_ball_point(self.centre, self.cutoff)

    def time_mda_PKDtree_with_setup(self, number_of_atoms, cutoff):
        """Benchmark tree construction + search."""
        tree = PeriodicKDTree(box=self.box)
        tree.set_coords(self.positions, cutoff=self.cutoff)
        tree.search(self.centre, self.cutoff)

    def time_scipy_cKDTree_with_setup(self, number_of_atoms, cutoff):
        """Benchmark tree construction + query."""
        tree = cKDTree(self.positions, boxsize=self.box[:3])
        tree.query_ball_point(self.centre, self.cutoff)

    def time_capped_distance_array(self, number_of_atoms, cutoff):
        """Benchmark capped distance calculation."""
        capped_distance(
            self.centre, self.positions, max_cutoff=self.cutoff, box=self.box
        )
244 changes: 242 additions & 2 deletions package/MDAnalysis/lib/pkdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@

from MDAnalysis.lib.distances import apply_PBC
import numpy.typing as npt
from typing import Optional, ClassVar
from typing import Optional, ClassVar, Union, Any

__all__ = ["PeriodicKDTree"]


class PeriodicKDTree(object):
class AugmentedPKDTree(object):
"""Wrapper around :class:`scipy.spatial.cKDTree`

Creates an object which can handle periodic as well as
Expand Down Expand Up @@ -336,3 +336,243 @@ class initialization
if pairs.size > 0:
pairs = unique_rows(pairs)
return pairs


class PeriodicKDTree(object):
    """Neighbor search front end with optional periodic boundaries.

    The back end is chosen once, at construction time:

    * no box, or a box of shape ``(6,)`` whose last three entries are all
      close to 90: a :class:`scipy.spatial.cKDTree` is used directly. When
      a box is present its first three entries are passed as ``boxsize``.
    * any other box: every call is forwarded to an
      :class:`AugmentedPKDTree` built with the same box and leafsize.

    Typical use is ``set_coords`` followed by one of ``search``,
    ``search_pairs`` or ``search_tree``.
    """

    def __init__(
        self, box: Optional[npt.ArrayLike] = None, leafsize: int = 10
    ) -> None:
        """
        Parameters
        ----------
        box : array_like or ``None``, optional
            Unit cell; ``None`` disables periodic boundary handling.
        leafsize : int, optional
            Passed on as ``leafsize`` to the underlying tree.
        """
        self.leafsize = leafsize
        self.dim = 3
        self.box = box
        self._built = False

        self.cutoff: Optional[float] = None
        self.mapping: Optional[npt.NDArray] = None
        self._tree: Optional[Union["AugmentedPKDTree", cKDTree]] = None
        # Result of the most recent ``search``; ``None`` until one ran.
        self._indices: Optional[npt.NDArray] = None

        use_augmented = False
        if box is not None:
            box_array = np.asarray(box, dtype=np.float32)
            # Anything that is not a six-element box with 90 degree
            # angles is handed to the augmented implementation.
            use_augmented = box_array.shape != (6,) or not np.allclose(
                box_array[3:], 90.0
            )

        self._use_augmented = use_augmented
        # Defined on every instance so that reading it never raises.
        self._is_ortho = box is not None and not use_augmented

        if use_augmented:
            self._tree = AugmentedPKDTree(box=self.box, leafsize=leafsize)
        elif box is not None:
            self.box = box_array

    @property
    def pbc(self):
        """Flag to indicate the presence of periodic boundaries.

        - ``True`` if PBC are taken into account
        - ``False`` if no unitcell dimension is available.

        This is a managed attribute and can only be read.
        """
        return self.box is not None

    def _as_centers(self, centers: npt.ArrayLike) -> npt.NDArray:
        """Return `centers` as float32, promoting a single point to (1, dim)."""
        centers = np.asarray(centers, dtype=np.float32)
        if centers.shape == (self.dim,):
            centers = centers.reshape((1, self.dim))
        return centers

    def _check_radius(self, radius: float) -> None:
        """Validate `radius` against the stored cutoff when PBC are active."""
        if not self.pbc:
            return
        if self.cutoff is None:
            raise ValueError(
                "Cutoff needs to be provided when working with PBC."
            )
        if self.cutoff < radius:
            raise RuntimeError("Set cutoff greater or equal to the radius.")

    def set_coords(
        self, coords: npt.ArrayLike, cutoff: Optional[float] = None
    ) -> None:
        """Constructs KDTree from the coordinates

        Parameters
        ----------
        coords: array_like
            Coordinate array of shape ``(N, 3)`` for N atoms.
        cutoff: float
            Specified cutoff distance for searches.
            Required for periodic calculations.

        Raises
        ------
        RuntimeError
            In the cKDTree back end: if `cutoff` is given without a box,
            or if a box is set and `cutoff` is ``None``.
        """
        if self._use_augmented:
            # Narrowing for type checkers; set in ``__init__``.
            assert self._tree is not None
            self._tree.set_coords(coords, cutoff)
            self.cutoff = cutoff
            self._built = True
            return

        coords = np.asarray(coords, dtype=np.float32)

        if self.box is None:
            if cutoff is not None:
                raise RuntimeError(
                    "Donot provide cutoff distance for non PBC aware calculations"
                )
            self.coords = coords
            self._tree = cKDTree(self.coords, leafsize=self.leafsize)
        else:
            if cutoff is None:
                raise RuntimeError(
                    "Provide a cutoff distance with tree.set_coords(...)"
                )
            # NOTE(review): cKDTree expects periodic input wrapped into the
            # box; this relies on apply_PBC providing that -- confirm for
            # float32 values that land exactly on the upper box edge.
            self.coords = apply_PBC(coords, self.box)
            box_array = np.asarray(self.box, dtype=np.float32)
            self._tree = cKDTree(
                self.coords, leafsize=self.leafsize, boxsize=box_array[:3]
            )

        # Only record the cutoff once it has been validated.
        self.cutoff = cutoff
        self._built = True

    def search(self, centers: npt.ArrayLike, radius: float) -> npt.NDArray:
        """Search all points within radius from centers and their periodic images.

        Parameters
        ----------
        centers: array_like (N,3)
            coordinate array to search for neighbors
        radius: float
            maximum distance to search for neighbors.

        Returns
        -------
        indices : NDArray
            Neighbor indices; the same array is afterwards available from
            :meth:`get_indices`.

        Raises
        ------
        RuntimeError
            If the tree was not built, or `radius` exceeds the cutoff.
        ValueError
            If PBC are active but no cutoff was stored.
        """
        if not self._built:
            raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)")

        if self._use_augmented:
            assert self._tree is not None
            # Keep the result so get_indices() also works for this back end.
            self._indices = self._tree.search(centers, radius)
            return self._indices

        centers = self._as_centers(centers)
        self._check_radius(radius)
        if self.pbc:
            centers = apply_PBC(centers, self.box)

        assert isinstance(self._tree, cKDTree)
        hits = self._tree.query_ball_point(centers, radius)
        # One list of hits per center; flatten them into a single array.
        indices = np.array(
            list(itertools.chain.from_iterable(hits)), dtype=np.intp
        )
        if indices.size > 0:
            indices = np.asarray(unique_int_1d(indices))

        self._indices = indices
        return indices

    def get_indices(self) -> npt.NDArray:
        """Return the neighbors from the last query.

        Returns
        ------
        indices : NDArray
            neighbors for the last query points and search radius

        Raises
        ------
        RuntimeError
            If :meth:`search` has not been called yet.
        """
        if self._indices is None:
            raise RuntimeError("No search performed. Run tree.search(...)")
        return self._indices

    def search_pairs(self, radius: float) -> npt.NDArray:
        """Search all the pairs within a specified radius

        Parameters
        ----------
        radius : float
            Maximum distance between pairs of coordinates

        Returns
        -------
        pairs : array
            Indices of all the pairs which are within the specified radius

        Raises
        ------
        RuntimeError
            If the tree was not built, or `radius` exceeds the cutoff.
        ValueError
            If PBC are active but no cutoff was stored.
        """
        if not self._built:
            raise RuntimeError("Unbuilt Tree. Run tree.set_coords(...)")

        if self._use_augmented:
            assert self._tree is not None
            return self._tree.search_pairs(radius)

        self._check_radius(radius)

        assert isinstance(self._tree, cKDTree)
        pairs = np.array(list(self._tree.query_pairs(radius)), dtype=np.intp)

        if pairs.size > 0:
            pairs = np.sort(pairs, axis=1)
            pairs = unique_rows(pairs)
        return pairs

    def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
        """
        Searches all the pairs within `radius` between `centers`
        and ``coords``

        ``coords`` are the already initialized coordinates in the tree
        during :meth:`set_coords`.

        Parameters
        ----------
        centers: array_like (N,3)
            coordinate array to search for neighbors
        radius: float
            maximum distance to search for neighbors.

        Returns
        -------
        pairs : array
            all the pairs between ``coords`` and ``centers``; the first
            column indexes `centers`, the second the tree coordinates.

        Raises
        ------
        RuntimeError
            If the tree was not built, or `radius` exceeds the cutoff.
        ValueError
            If PBC are active but no cutoff was stored.
        """
        if not self._built:
            raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)")

        if self._use_augmented:
            assert self._tree is not None
            return self._tree.search_tree(centers, radius)

        centers = self._as_centers(centers)
        self._check_radius(radius)

        # A second tree over the query points, with the same periodicity
        # as the stored tree, so both can be matched in one call.
        if self.pbc:
            box_array = np.asarray(self.box, dtype=np.float32)
            other_tree = cKDTree(
                apply_PBC(centers, self.box),
                leafsize=self.leafsize,
                boxsize=box_array[:3],
            )
        else:
            other_tree = cKDTree(centers, leafsize=self.leafsize)

        neighbors = other_tree.query_ball_tree(self._tree, radius)
        pairs = np.array(
            [[i, j] for i, hits in enumerate(neighbors) for j in hits],
            dtype=np.intp,
        )

        if pairs.size > 0:
            pairs = unique_rows(pairs)
        return pairs
Loading
Loading