Skip to content

Commit

Permalink
refactor: tried to improve array_library docstrings, and removed unus…
Browse files Browse the repository at this point in the history
…ed numpy import in smallpebble.py
  • Loading branch information
sradc committed Feb 17, 2022
1 parent 4139aa9 commit 369409b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
63 changes: 42 additions & 21 deletions smallpebble/array_library.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,59 @@
"""This module acts as a proxy, allowing NumPy/CuPy to be switched dynamically.
The default library is NumPy.
To switch to CuPy:
>> import smallpebble as sp
>> import cupy
>> sp.array_library.library = cupy
To switch back to NumPy:
>> import numpy
>> sp.array_library.library = numpy
"""Allows SmallPebble to dynamically switch between NumPy/CuPy,
for its ndarray computations. The default is NumPy.
Note to SmallPebble devs:
Watch out for cases where NumPy and CuPy differ,
e.g. np.add.at is cupy.scatter_add.
"""
from types import ModuleType

import numpy

library = numpy # numpy or cupy


def use(array_library):
"""Set array library to be NumPy or CuPy.
def use(array_library: ModuleType) -> None:
"""Set the array library that SmallPebble will use.
E.g.
>> import cupy
>> import smallpebble as sp
>> sp.use(cupy)
Parameters
----------
array_library : ModuleType
Either NumPy (the SmallPebble default) or CuPy (for GPU acceleration).
To switch back to NumPy:
>> import numpy
>> sp.use(numpy)
Example:
```python
# Switch array library to CuPy.
import cupy
import smallpebble as sp
sp.use(cupy)
# To switch back to NumPy:
import numpy
sp.use(numpy)
```
"""

global library
library = array_library


def __getattr__(name):
def __getattr__(name: str):
"""Make this module act as a proxy, for NumPy/CuPy.
Here's an example:
```python
import smallpebble.array_library as array_library
x = array_library.array([1, 2, 3]) ** 2 # <- a NumPy array
```
In this example, a NumPy array is created,
because `array_library.array` results in this function
being called, which then calls `getattr(numpy, "array")`,
which is NumPy's function for creating arrays.
(The above assumes that `library == numpy`, which is the
default. The same thing would happen but with CuPy arrays,
if `library == cupy`.)
"""
return getattr(library, name)
3 changes: 1 addition & 2 deletions smallpebble/smallpebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
from collections import defaultdict
import math
import numpy
import smallpebble.array_library as np


Expand Down Expand Up @@ -624,7 +623,7 @@ def np_add_at(a, indices, b):
if np.library.__name__ == "numpy":
np.add.at(a, indices, b)
elif np.library.__name__ == "cupy":
np.scatter_add(a, indices, b)
np._cupyx.scatter_add(a, indices, b)
else:
raise ValueError("Expected np.library.__name__ to be `numpy` or `cupy`.")

Expand Down

0 comments on commit 369409b

Please sign in to comment.