-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: tried to improve array_library docstrings, and removed unus…
…ed numpy import in smallpebble.py
- Loading branch information
Showing
2 changed files
with
43 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,59 @@ | ||
"""This module acts as a proxy, allowing NumPy/CuPy to be switched dynamically. | ||
The default library is NumPy. | ||
To switch to CuPy: | ||
>> import smallpebble as sp | ||
>> import cupy | ||
>> sp.array_library.library = cupy | ||
To switch back to NumPy: | ||
>> import numpy | ||
>> sp.array_library.library = numpy | ||
"""Allows SmallPebble to dynamically switch between NumPy/CuPy, | ||
for its ndarray computations. The default is NumPy. | ||
Note to SmallPebble devs: | ||
Watch out for cases where NumPy and CuPy differ, | ||
e.g. np.add.at is cupy.scatter_add. | ||
""" | ||
from types import ModuleType | ||
|
||
import numpy | ||
|
||
library = numpy # numpy or cupy | ||
|
||
|
||
def use(array_library): | ||
"""Set array library to be NumPy or CuPy. | ||
def use(array_library: ModuleType) -> None: | ||
"""Set the array library that SmallPebble will use. | ||
E.g. | ||
>> import cupy | ||
>> import smallpebble as sp | ||
>> sp.use(cupy) | ||
Parameters | ||
---------- | ||
array_library : ModuleType | ||
Either NumPy (the SmallPebble default) or CuPy (for GPU acceleration). | ||
To switch back to NumPy: | ||
>> import numpy | ||
>> sp.use(numpy) | ||
Example: | ||
```python | ||
# Switch array library to CuPy. | ||
import cupy | ||
import smallpebble as sp | ||
sp.use(cupy) | ||
# To switch back to NumPy: | ||
import numpy | ||
sp.use(numpy) | ||
``` | ||
""" | ||
|
||
global library | ||
library = array_library | ||
|
||
|
||
def __getattr__(name): | ||
def __getattr__(name: str): | ||
"""Make this module act as a proxy, for NumPy/CuPy. | ||
Here's an example: | ||
```python | ||
import smallpebble.array_library as array_library | ||
x = array_library.array([1, 2, 3]) ** 2 # <- a NumPy array | ||
``` | ||
In this example, a NumPy array is created, | ||
because `array_library.array` results in this function | ||
being called, which then calls `getattr(numpy, "array")`, | ||
which is NumPy's function for creating arrays. | ||
(The above assumes that `library == numpy`, which is the | ||
default. The same thing would happen but with CuPy arrays, | ||
if `library == cupy`.) | ||
""" | ||
return getattr(library, name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters