-
Notifications
You must be signed in to change notification settings - Fork 0
Add a scipy.sparse numba extension #73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
aace7f4
wip
flying-sheep 7b7dac2
lint
flying-sheep da2fd7f
fix import
flying-sheep 050edb7
try without casts
flying-sheep b0c6c2e
Merge branch 'main' into pa/nbext
flying-sheep c2a473d
Revert "try without casts"
flying-sheep c11550f
WIP
flying-sheep ad2d5a3
WIP
flying-sheep 1bcedcb
WIP
flying-sheep c9e197d
bigger bench
flying-sheep 0c36092
Merge branch 'main' into pa/nbext
flying-sheep 59d511a
WIP
flying-sheep 65b7b31
fix typing
flying-sheep d85d644
lints
flying-sheep 484c322
skip numba import
flying-sheep d198966
coverage
flying-sheep 9ac3887
fix ndim
flying-sheep 8d62ad3
urgh
flying-sheep 806fc0b
more cov
flying-sheep 774bfa4
Update src/fast_array_utils/_plugins/numba_sparse.py
flying-sheep e018ecb
Document extending
flying-sheep b451be6
Some type fixes
flying-sheep 394485d
tests
flying-sheep c963b3c
clarify tests
flying-sheep 50d9de0
adapt tests
flying-sheep 38f9d22
fix tests
flying-sheep d30b2f2
last docs
flying-sheep e29172b
allow random int arrays
flying-sheep a403639
fix docs
flying-sheep 59b9fc3
coverage
flying-sheep File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# SPDX-License-Identifier: MPL-2.0 | ||
from __future__ import annotations | ||
|
||
|
||
__all__ = ["patch_dask", "register_numba_sparse"] | ||
|
||
|
||
def patch_dask() -> None: | ||
r"""Patch Dask Arrays so it supports `scipy.sparse.sparray`\ s.""" | ||
try: | ||
from .dask import patch | ||
except ImportError: | ||
pass | ||
else: | ||
patch() | ||
|
||
|
||
def register_numba_sparse() -> None: | ||
r"""Register `scipy.sparse.sp{matrix,array}`\ s with Numba. | ||
|
||
This makes it cleaner to write numba functions operating on these types. | ||
""" | ||
try: | ||
from .numba_sparse import register | ||
except ImportError: | ||
pass | ||
else: | ||
register() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
# SPDX-License-Identifier: MPL-2.0 | ||
"""Numba support for sparse arrays and matrices.""" | ||
|
||
# taken from https://github.com/numba/numba-scipy/blob/release0.4/numba_scipy/sparse.py | ||
# See https://numba.pydata.org/numba-doc/dev/extending/ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, cast | ||
|
||
import numba.core.types as nbtypes | ||
import numpy as np | ||
from numba.core import cgutils | ||
from numba.core.imputils import impl_ret_borrowed | ||
from numba.extending import ( | ||
NativeValue, | ||
box, | ||
intrinsic, | ||
make_attribute_wrapper, | ||
models, | ||
overload, | ||
overload_attribute, | ||
overload_method, | ||
register_model, | ||
typeof_impl, | ||
unbox, | ||
) | ||
from scipy import sparse | ||
|
||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable, Mapping, Sequence | ||
from typing import Any, ClassVar, Literal | ||
|
||
from llvmlite.ir import IRBuilder, Value | ||
from numba.core.base import BaseContext | ||
from numba.core.datamodel.manager import DataModelManager | ||
from numba.core.extending import BoxContext, TypingContext, UnboxContext | ||
from numba.core.typing.templates import Signature | ||
from numba.core.typing.typeof import _TypeofContext | ||
from numpy.typing import NDArray | ||
|
||
from fast_array_utils.types import CSBase | ||
|
||
|
||
class CSType(nbtypes.Type): | ||
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`. | ||
|
||
This is an abstract base class for the actually used, registered types in `TYPES` below. | ||
It collects information about the type (e.g. field dtypes) for later use in the data model. | ||
""" | ||
|
||
name: ClassVar[Literal["csr_matrix", "csc_matrix", "csr_array", "csc_array"]] | ||
cls: ClassVar[type[CSBase]] | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def instance_class( | ||
cls, | ||
data: NDArray[np.number[Any]], | ||
indices: NDArray[np.integer[Any]], | ||
indptr: NDArray[np.integer[Any]], | ||
shape: tuple[int, int], # actually tuple[int, ...] for sparray subclasses | ||
) -> CSBase: | ||
return cls.cls((data, indices, indptr), shape, copy=False) | ||
|
||
def __init__(self, ndim: int, *, dtype: nbtypes.Type, dtype_ind: nbtypes.Type) -> None: | ||
self.dtype = nbtypes.DType(dtype) | ||
self.dtype_ind = nbtypes.DType(dtype_ind) | ||
self.data = nbtypes.Array(dtype, 1, "A") | ||
self.indices = nbtypes.Array(dtype_ind, 1, "A") | ||
self.indptr = nbtypes.Array(dtype_ind, 1, "A") | ||
self.shape = nbtypes.UniTuple(nbtypes.intp, ndim) | ||
super().__init__(self.name) | ||
|
||
@property | ||
def key(self) -> tuple[str | nbtypes.Type, ...]: | ||
return (self.name, self.dtype, self.dtype_ind) | ||
|
||
|
||
# make data model attributes available in numba functions | ||
for attr in ["data", "indices", "indptr", "shape"]: | ||
make_attribute_wrapper(CSType, attr, attr) | ||
|
||
|
||
def make_typeof_fn(typ: type[CSType]) -> Callable[[CSBase, _TypeofContext], CSType]: | ||
"""Create a `typeof` function that maps a scipy matrix/array type to a numba `Type`.""" | ||
|
||
def typeof(val: CSBase, c: _TypeofContext) -> CSType: | ||
if val.indptr.dtype != val.indices.dtype: # pragma: no cover | ||
msg = "indptr and indices must have the same dtype" | ||
raise TypeError(msg) | ||
data = cast("nbtypes.Array", typeof_impl(val.data, c)) | ||
indptr = cast("nbtypes.Array", typeof_impl(val.indptr, c)) | ||
return typ(val.ndim, dtype=data.dtype, dtype_ind=indptr.dtype) | ||
|
||
return typeof | ||
|
||
|
||
if TYPE_CHECKING: | ||
_CSModelBase = models.StructModel[CSType] | ||
else: | ||
_CSModelBase = models.StructModel | ||
|
||
|
||
class CSModel(_CSModelBase): | ||
"""Numba data model for compressed sparse matrices. | ||
|
||
This is the class that is used by numba to lower the array types. | ||
""" | ||
|
||
def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None: | ||
members = [ | ||
("data", fe_type.data), | ||
("indices", fe_type.indices), | ||
("indptr", fe_type.indptr), | ||
("shape", fe_type.shape), | ||
] | ||
super().__init__(dmm, fe_type, members) | ||
|
||
|
||
# create all the actual types and data models | ||
CLASSES: Sequence[type[CSBase]] = [ | ||
sparse.csr_matrix, | ||
sparse.csc_matrix, | ||
sparse.csr_array, | ||
sparse.csc_array, | ||
] | ||
TYPES: Sequence[type[CSType]] = [ | ||
type(f"{cls.__name__}Type", (CSType,), {"cls": cls, "name": cls.__name__}) for cls in CLASSES | ||
] | ||
TYPEOF_FUNCS: Mapping[type[CSBase], Callable[[CSBase, _TypeofContext], CSType]] = { | ||
typ.cls: make_typeof_fn(typ) for typ in TYPES | ||
} | ||
MODELS: Mapping[type[CSType], type[CSModel]] = { | ||
typ: type(f"{typ.cls.__name__}Model", (CSModel,), {}) for typ in TYPES | ||
} | ||
|
||
|
||
def unbox_matrix(typ: CSType, obj: Value, c: UnboxContext) -> NativeValue: | ||
"""Convert a Python cs{rc}_{matrix,array} to a Numba value.""" | ||
struct_proxy_cls = cgutils.create_struct_proxy(typ) | ||
struct_ptr = struct_proxy_cls(c.context, c.builder) | ||
|
||
data = c.pyapi.object_getattr_string(obj, "data") | ||
indices = c.pyapi.object_getattr_string(obj, "indices") | ||
indptr = c.pyapi.object_getattr_string(obj, "indptr") | ||
shape = c.pyapi.object_getattr_string(obj, "shape") | ||
|
||
struct_ptr.data = c.unbox(typ.data, data).value | ||
struct_ptr.indices = c.unbox(typ.indices, indices).value | ||
struct_ptr.indptr = c.unbox(typ.indptr, indptr).value | ||
struct_ptr.shape = c.unbox(typ.shape, shape).value | ||
|
||
c.pyapi.decref(data) | ||
c.pyapi.decref(indices) | ||
c.pyapi.decref(indptr) | ||
c.pyapi.decref(shape) | ||
|
||
is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) | ||
is_error = c.builder.load(is_error_ptr) | ||
|
||
return NativeValue(struct_ptr._getvalue(), is_error=is_error) # noqa: SLF001 | ||
|
||
|
||
def box_matrix(typ: CSType, val: NativeValue, c: BoxContext) -> Value: | ||
"""Convert numba value into a Python cs{rc}_{matrix,array}.""" | ||
struct_proxy_cls = cgutils.create_struct_proxy(typ) | ||
struct_ptr = struct_proxy_cls(c.context, c.builder, value=val) | ||
|
||
data_obj = c.box(typ.data, struct_ptr.data) | ||
indices_obj = c.box(typ.indices, struct_ptr.indices) | ||
indptr_obj = c.box(typ.indptr, struct_ptr.indptr) | ||
shape_obj = c.box(typ.shape, struct_ptr.shape) | ||
|
||
c.pyapi.incref(data_obj) | ||
c.pyapi.incref(indices_obj) | ||
c.pyapi.incref(indptr_obj) | ||
c.pyapi.incref(shape_obj) | ||
|
||
cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) | ||
obj = c.pyapi.call_function_objargs(cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj)) | ||
|
||
c.pyapi.decref(data_obj) | ||
c.pyapi.decref(indices_obj) | ||
c.pyapi.decref(indptr_obj) | ||
c.pyapi.decref(shape_obj) | ||
|
||
return obj | ||
|
||
|
||
# See https://numba.readthedocs.io/en/stable/extending/overloading-guide.html | ||
@overload(np.shape) | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def overload_sparse_shape(x: CSType) -> None | Callable[[CSType], nbtypes.UniTuple]: | ||
if not isinstance(x, CSType): # pragma: no cover | ||
return None | ||
|
||
# nopython code: | ||
def shape(x: CSType) -> nbtypes.UniTuple: # pragma: no cover | ||
return x.shape | ||
|
||
return shape | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@overload_attribute(CSType, "ndim") | ||
def overload_sparse_ndim(inst: CSType) -> None | Callable[[CSType], int]: | ||
if not isinstance(inst, CSType): # pragma: no cover | ||
return None | ||
|
||
# nopython code: | ||
def ndim(inst: CSType) -> int: # pragma: no cover | ||
return len(inst.shape) | ||
|
||
return ndim | ||
|
||
|
||
@intrinsic | ||
def _sparse_copy( | ||
typingctx: TypingContext, # noqa: ARG001 | ||
inst: CSType, | ||
data: nbtypes.Array, # noqa: ARG001 | ||
indices: nbtypes.Array, # noqa: ARG001 | ||
indptr: nbtypes.Array, # noqa: ARG001 | ||
shape: nbtypes.UniTuple, # noqa: ARG001 | ||
) -> tuple[Signature, Callable[..., NativeValue]]: | ||
def _construct( | ||
context: BaseContext, | ||
builder: IRBuilder, | ||
sig: Signature, | ||
args: tuple[Value, Value, Value, Value, Value], | ||
) -> NativeValue: | ||
struct_proxy_cls = cgutils.create_struct_proxy(sig.return_type) | ||
struct = struct_proxy_cls(context, builder) | ||
_, data, indices, indptr, shape = args | ||
struct.data = data | ||
struct.indices = indices | ||
struct.indptr = indptr | ||
struct.shape = shape | ||
return impl_ret_borrowed( | ||
context, | ||
builder, | ||
sig.return_type, | ||
struct._getvalue(), # noqa: SLF001 | ||
) | ||
|
||
sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) | ||
|
||
return sig, _construct | ||
|
||
|
||
@overload_method(CSType, "copy") | ||
def overload_sparse_copy(inst: CSType) -> None | Callable[[CSType], CSType]: | ||
if not isinstance(inst, CSType): # pragma: no cover | ||
return None | ||
|
||
# nopython code: | ||
def copy(inst: CSType) -> CSType: # pragma: no cover | ||
return _sparse_copy( | ||
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape | ||
) # type: ignore[return-value] | ||
|
||
return copy | ||
|
||
|
||
def register() -> None: | ||
"""Register the numba types, data models, and mappings between them and the Python types.""" | ||
for cls, func in TYPEOF_FUNCS.items(): | ||
typeof_impl.register(cls, func) | ||
for typ, model in MODELS.items(): | ||
register_model(typ)(model) | ||
unbox(typ)(unbox_matrix) | ||
box(typ)(box_matrix) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.