diff --git a/smallpebble/array_library.py b/smallpebble/array_library.py index 2463a28..078cc6d 100644 --- a/smallpebble/array_library.py +++ b/smallpebble/array_library.py @@ -1,38 +1,59 @@ -"""This module acts as a proxy, allowing NumPy/CuPy to be switched dynamically. -The default library is NumPy. - -To switch to CuPy: ->> import smallpebble as sp ->> import cupy ->> sp.array_library.library = cupy - -To switch back to NumPy: ->> import numpy ->> sp.array_library.library = numpy +"""Allows SmallPebble to dynamically switch between NumPy/CuPy, +for its ndarray computations. The default is NumPy. +Note to SmallPebble devs: Watch out for cases where NumPy and CuPy differ, e.g. np.add.at is cupy.scatter_add. """ +from types import ModuleType + import numpy library = numpy # numpy or cupy -def use(array_library): - """Set array library to be NumPy or CuPy. +def use(array_library: ModuleType) -> None: + """Set the array library that SmallPebble will use. - E.g. - >> import cupy - >> import smallpebble as sp - >> sp.use(cupy) + Parameters + ---------- + array_library : ModuleType + Either NumPy (the SmallPebble default) or CuPy (for GPU acceleration). - To switch back to NumPy: - >> import numpy - >> sp.use(numpy) + Example: + ```python + # Switch array library to CuPy. + import cupy + import smallpebble as sp + sp.use(cupy) + + # To switch back to NumPy: + import numpy + sp.use(numpy) + ``` """ + global library library = array_library -def __getattr__(name): +def __getattr__(name: str): + """Make this module act as a proxy, for NumPy/CuPy. + + Here's an example: + + ```python + import smallpebble.array_library as array_library + + x = array_library.array([1, 2, 3]) ** 2 # <- a NumPy array + ``` + + In this example, a NumPy array is created, + because `array_library.array` results in this function + being called, which then calls `getattr(numpy, "array")`, + which is NumPy's function for creating arrays. + (The above assumes that `library == numpy`, which is the + default. The same thing would happen but with CuPy arrays, + if `library == cupy`.) + """ return getattr(library, name) diff --git a/smallpebble/smallpebble.py b/smallpebble/smallpebble.py index 8bea4ac..97fe2b2 100644 --- a/smallpebble/smallpebble.py +++ b/smallpebble/smallpebble.py @@ -19,7 +19,6 @@ """ from collections import defaultdict import math -import numpy import smallpebble.array_library as np @@ -624,7 +623,7 @@ def np_add_at(a, indices, b): if np.library.__name__ == "numpy": np.add.at(a, indices, b) elif np.library.__name__ == "cupy": - np.scatter_add(a, indices, b) + np._cupyx.scatter_add(a, indices, b) else: raise ValueError("Expected np.library.__name__ to be `numpy` or `cupy`.")