246 changes: 246 additions & 0 deletions docs/source/guides/array_api.rst
@@ -0,0 +1,246 @@
Array API Compatibility
=======================

ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
to allow processors to operate on arrays from different backends — NumPy, CuPy,
PyTorch, and others — without code changes.

.. contents:: On this page
:local:
:depth: 2


How It Works
------------

Modules that support the Array API derive the array namespace from their input
data using ``array_api_compat.get_namespace()``:

.. code-block:: python

from array_api_compat import get_namespace

def process(self, data):
xp = get_namespace(data) # numpy, cupy, torch, etc.
result = xp.linalg.inv(data) # dispatches to the right backend
return result

This means that if you pass a CuPy array, all computation stays on the GPU.
If you pass a NumPy array, it behaves exactly as before.
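
For instance, with a processor following the pattern above (a sketch;
``processor`` is hypothetical, and the CuPy call requires a CUDA device):

.. code-block:: python

    import cupy as cp
    import numpy as np

    out_cpu = processor.process(np.random.rand(4, 4))  # numpy in, numpy out
    out_gpu = processor.process(cp.random.rand(4, 4))  # cupy in, stays on the GPU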

Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
and creation functions portably (a combined usage sketch follows the list):

- ``array_device(x)`` — returns the device of an array, or ``None``
- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
functions (``zeros``, ``eye``) with optional device
- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``
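
A minimal sketch of how these helpers combine when initialising state on the
same device as the incoming data (``init_state`` and ``n`` are illustrative,
not part of the library):

.. code-block:: python

    from array_api_compat import get_namespace
    from ezmsg.sigproc.util.array import array_device, xp_asarray, xp_create

    def init_state(data, n):
        xp = get_namespace(data)
        dev = array_device(data)  # None for plain numpy arrays
        cov = xp_create(xp.eye, n, device=dev)        # identity on that device
        mean = xp_create(xp.zeros, (n,), device=dev)  # zeros on that device
        prev = xp_asarray(xp, [0.0] * n, device=dev)  # portable asarray
        return cov, mean, prev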


Module Compatibility
--------------------

The table below summarises the Array API status of each module.

Fully compatible
^^^^^^^^^^^^^^^^

These modules perform all computation in the source array namespace.

.. list-table::
:header-rows: 1
:widths: 35 65

* - Module
- Notes
* - ``process.ssr``
- LRR / self-supervised regression. Full Array API.
* - ``model.cca``
- Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
eigendecomposition-based inverse square root using only Array API ops.
* - ``process.rnn``
- PyTorch-native; operates on ``torch.Tensor`` throughout.

Mostly compatible (with NumPy boundaries)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

These modules use the Array API for data manipulation but fall back to NumPy
at specific points where a dependency requires it.

.. list-table::
:header-rows: 1
:widths: 25 35 40

* - Module
- NumPy boundary
- Reason
* - ``model.refit_kalman``
- ``_compute_gain()``
     - ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
       Matrices are converted to NumPy for the DARE solver, then converted
       back (see the sketch after this table).
* - ``model.refit_kalman``
- ``refit()`` mutation loop
- Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
and scalar element assignment.
* - ``process.refit_kalman``
- Inherits boundaries from model
- State init and output arrays use the source namespace.
* - ``process.slda``
- ``predict_proba``
- sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
* - ``process.adaptive_linear_regressor``
- ``partial_fit`` / ``predict``
- sklearn and river models require NumPy / pandas input.
* - ``dim_reduce.adaptive_decomp``
- ``partial_fit`` / ``transform``
- sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.
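
The DARE boundary noted for ``model.refit_kalman`` follows the same
convert-and-restore pattern used throughout. A rough sketch (illustrative
only; it does not mirror the actual ``_compute_gain()`` implementation):

.. code-block:: python

    import numpy as np
    from scipy.linalg import solve_discrete_are

    def steady_state_gain(xp, A, W, H, Q):
        # scipy has no Array API support: convert to numpy at the boundary
        A_np, W_np = np.asarray(A), np.asarray(W)
        H_np, Q_np = np.asarray(H), np.asarray(Q)
        # Solve the discrete algebraic Riccati equation for the steady-state
        # error covariance, then form the Kalman gain
        P = solve_discrete_are(A_np.T, H_np.T, W_np, Q_np)
        K = P @ H_np.T @ np.linalg.inv(H_np @ P @ H_np.T + Q_np)
        # Convert the result back to the source namespace
        return xp.asarray(K)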

Not converted
^^^^^^^^^^^^^

These modules use NumPy directly. Conversion would provide little benefit
because the underlying estimator is the bottleneck.

.. list-table::
:header-rows: 1
:widths: 25 75

* - Module
- Reason
* - ``process.linear_regressor``
- Thin wrapper around sklearn ``LinearModel.predict``.
Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
(see below).
* - ``process.sgd``
- sklearn ``SGDClassifier`` has no Array API support.
* - ``process.sklearn``
- Generic wrapper for arbitrary models; cannot assume Array API support.
* - ``dim_reduce.incremental_decomp``
- Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on
Python tuples).


sklearn Array API Dispatch
--------------------------

scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
of estimators. Two estimators used in ezmsg-learn are on the supported list:

.. list-table::
:header-rows: 1
:widths: 30 30 40

* - Estimator
- Used in
- Constraint
* - ``LinearDiscriminantAnalysis``
- ``process.slda``
- Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
is not supported)
* - ``Ridge``
- ``process.linear_regressor``
- Requires ``solver="svd"``

To use dispatch, enable it before creating the estimator:

.. code-block:: python

from sklearn import set_config
set_config(array_api_dispatch=True)
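
For example, a sketch of fitting ``LinearDiscriminantAnalysis`` on PyTorch
tensors with dispatch enabled (illustrative data; requires a sklearn version
with Array API support and the ``array-api-compat`` package installed):

.. code-block:: python

    import torch
    from sklearn import set_config
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

    set_config(array_api_dispatch=True)

    X = torch.randn(100, 8)
    y = torch.randint(0, 2, (100,))

    lda = LinearDiscriminantAnalysis(solver="svd")  # only "svd" is dispatched
    lda.fit(X, y)                 # computation runs in torch
    proba = lda.predict_proba(X)  # returns a torch.Tensor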

.. warning::

- ``array_api_dispatch`` is marked **experimental** in sklearn.
- Solver constraints (``solver="svd"``) may produce slightly different
numerical results compared to other solvers.
- Enabling dispatch globally may affect other sklearn estimators in the
same process.
- ezmsg-learn does **not** enable dispatch by default.

Estimators that do **not** support Array API dispatch:

- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
- All river models


Writing Array API-Compatible Code
----------------------------------

When adding or modifying processors in ezmsg-learn, follow these patterns.

Deriving the namespace
^^^^^^^^^^^^^^^^^^^^^^

Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:

.. code-block:: python

from array_api_compat import get_namespace
from ezmsg.sigproc.util.array import array_device, xp_create

def _process(self, message):
xp = get_namespace(message.data)
dev = array_device(message.data)

Transposing matrices
^^^^^^^^^^^^^^^^^^^^

The Array API defines ``.T`` only for 2-D arrays. Use
``xp.linalg.matrix_transpose()``, which also handles stacks of matrices:

.. code-block:: python

# Before (numpy-only)
result = A.T @ B

# After (Array API)
_mT = xp.linalg.matrix_transpose
result = _mT(A) @ B

Creating arrays
^^^^^^^^^^^^^^^

Use ``xp_create`` to handle device placement portably:

.. code-block:: python

# Before
I = np.eye(n)
z = np.zeros((m, n), dtype=np.float64)

# After
I = xp_create(xp.eye, n, device=dev)
z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)

Handling sklearn boundaries
^^^^^^^^^^^^^^^^^^^^^^^^^^^

When calling into sklearn (or other NumPy-only libraries), convert at the
boundary and convert back:

.. code-block:: python

    import numpy as np
    from array_api_compat import get_namespace, is_numpy_array

    xp = get_namespace(X)

    # Convert to numpy for sklearn
    X_np = np.asarray(X) if not is_numpy_array(X) else X
    result_np = estimator.predict(X_np)

    # Convert back to the source namespace
    result = xp.asarray(result_np) if not is_numpy_array(X) else result_np

Checking for NaN
^^^^^^^^^^^^^^^^

Use ``xp.isnan`` instead of ``np.isnan``:

.. code-block:: python

if xp.any(xp.isnan(message.data)):
return

Norms
^^^^^

Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.
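
For example, following the before/after pattern above (``A``, ``B``, and
``v`` are illustrative):

.. code-block:: python

    # Before (numpy-only)
    err = np.linalg.norm(A - B)  # Frobenius norm for 2-D input
    mag = np.linalg.norm(v)      # 2-norm for 1-D input

    # After (Array API)
    err = xp.linalg.matrix_norm(A - B)  # Frobenius by default
    mag = xp.linalg.vector_norm(v)      # 2-norm by default
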
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org <https://www.ezmsg.org>
:caption: Contents:

guides/classification
guides/array_api
api/index


36 changes: 28 additions & 8 deletions src/ezmsg/learn/dim_reduce/adaptive_decomp.py
@@ -1,7 +1,18 @@
"""Adaptive decomposition transformers (PCA, NMF).

.. note::
This module supports the Array API standard via
``array_api_compat.get_namespace()``. Reshaping and output allocation
use Array API operations; a NumPy boundary is applied before sklearn
``partial_fit``/``transform`` calls.
"""

import math
import typing

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
@@ -128,6 +139,8 @@ def _process(self, message: AxisArray) -> AxisArray:
if in_dat.shape[ax_idx] == 0:
return self._state.template

xp = get_namespace(in_dat)

# Re-order axes
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
if message.dims != sorted_dims_exp:
@@ -137,16 +150,20 @@ def _process(self, message: AxisArray) -> AxisArray:
pass

# fold [iter_axis] + off_targ_axes together and fold targ_axes together
-        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
-        in_dat = in_dat.reshape((-1, d2))
+        d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = xp.reshape(in_dat, (-1, d2))

replace_kwargs = {
"axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
}

-        # Transform data
+        # Transform data — sklearn needs numpy
if hasattr(self._state.estimator, "components_"):
-            decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
+            in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+            decomp_dat = self._state.estimator.transform(in_np)
+            # Convert back to source namespace
+            decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
+            decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
replace_kwargs["data"] = decomp_dat

return replace(self._state.template, **replace_kwargs)
@@ -165,18 +182,21 @@ def partial_fit(self, message: AxisArray) -> None:
if in_dat.shape[ax_idx] == 0:
return

xp = get_namespace(in_dat)

# Re-order axes if needed
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
if message.dims != sorted_dims_exp:
# TODO: Implement axes transposition if needed
pass

# fold [iter_axis] + off_targ_axes together and fold targ_axes together
-        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
-        in_dat = in_dat.reshape((-1, d2))
+        d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = xp.reshape(in_dat, (-1, d2))

-        # Fit the estimator
-        self._state.estimator.partial_fit(in_dat)
+        # Fit the estimator — sklearn needs numpy
+        in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+        self._state.estimator.partial_fit(in_np)


class IncrementalPCASettings(AdaptiveDecompSettings):