Description
Some work has been started in sklearn to interface estimators that are written using only typical array functions with array libraries that implement the Array API, see scikit-learn/scikit-learn#22554.
In this effort, the default data validation function sklearn.utils.validation.check_array has been adapted to work with all array libraries that implement the Array API.
For the plugin system we're considering at scikit-learn/scikit-learn#24497, along with our plugin at https://github.com/soda-inria/sklearn-numba-dpex, I've found that it could be useful to re-use sklearn.utils.validation.check_array on usm_ndarray inputs: it would apply the sklearn validation rules to them, and it might also prevent unnecessary data copies.
I've found that a usm_ndarray currently fails check_array for these two reasons:
- it requires dpctl.tensor.isfinite to be implemented
- it requires the .__array_namespace__ attribute of usm_ndarray arrays to return dpctl.tensor rather than None
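To make the two requirements above concrete, here is a minimal sketch of a check that reports which of them an array object fails. The helper array_api_blockers and the FakeArray stand-in are hypothetical names introduced only for illustration; they are not part of sklearn or dpctl, and the sketch only mimics the __array_namespace__ protocol rather than reproducing what check_array does internally.

```python
from types import SimpleNamespace


def array_api_blockers(array):
    """List the reasons why `array` would miss the two requirements.

    Hypothetical helper for illustration, not sklearn or dpctl API.
    """
    blockers = []
    get_namespace = getattr(array, "__array_namespace__", None)
    namespace = get_namespace() if callable(get_namespace) else None
    if namespace is None:
        # Without a namespace the isfinite check cannot be attempted at all.
        blockers.append("__array_namespace__ is missing or returns None")
    elif not hasattr(namespace, "isfinite"):
        blockers.append("namespace does not implement isfinite")
    return blockers


class FakeArray:
    """Stand-in array type that only exposes __array_namespace__."""

    def __init__(self, namespace):
        self._namespace = namespace

    def __array_namespace__(self, api_version=None):
        return self._namespace


# Three situations: no namespace, a namespace lacking isfinite, a complete one.
no_namespace = FakeArray(None)
partial = FakeArray(SimpleNamespace(asarray=list))
complete = FakeArray(SimpleNamespace(asarray=list, isfinite=lambda x: True))

for name, arr in [("no_namespace", no_namespace),
                  ("partial", partial),
                  ("complete", complete)]:
    print(name, array_api_blockers(arr))
```

With a real usm_ndarray in place of FakeArray, an empty list would suggest that neither of the two blockers listed above applies any more.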
In the meantime it's possible to work around those two missing features with:
import numpy as np
import dpctl.tensor as dpt
from sklearn import config_context
from sklearn.utils.validation import check_array

# Build a small usm_ndarray from a NumPy array
array = dpt.asarray(np.arange(12).reshape(4, 3))

with config_context(array_api_dispatch=True, assume_finite=True):  # workaround 1: assume_finite=True
    array._set_namespace(dpt)  # workaround 2: manually call ._set_namespace
    checked_array = check_array(array, accept_sparse=False, dtype=[np.float32, np.float64], order="C", copy=False, force_all_finite=True)