Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pmdarima/arima/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numbers
import warnings
from scipy.stats import gaussian_kde, norm
from sklearn.utils.validation import check_array
from statsmodels import api as sm

from . import _validation as val
Expand All @@ -25,7 +24,7 @@
from ..compat import matplotlib as mpl_compat
from ..utils import if_has_delegate, is_iterable, check_endog, check_exog
from ..utils.visualization import _get_plt
from ..utils.array import diff_inv, diff
from ..utils.array import diff_inv, diff, check_array

# Get the version
import pmdarima
Expand Down Expand Up @@ -804,7 +803,7 @@ def predict(self,
# The confidence intervals may be a Pandas frame if it comes from
# SARIMAX & we want Numpy. We will to duck type it so we don't add
# new explicit requirements for the package
return f, check_array(conf_int, force_all_finite=False)
return f, check_array(conf_int, ensure_all_finite=False)
return f

def __getstate__(self):
Expand Down
40 changes: 30 additions & 10 deletions pmdarima/utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@
from ..compat import DTYPE
from ._array import C_intgrt_vec

import sklearn
from sklearn.utils import check_array as _check_array


# patched version of sklearn.utils.validation.check_array
# until sklearn minimal version for this project is 1.6
# to remove deprecation warning and ensure compatibility with sklearn 1.8
sklearn_ver = list(map(int, sklearn.__version__.split(".")))
_new_check_array_api = sklearn_ver[0] >= 1 and sklearn_ver[1] >= 6


def check_array(*args, ensure_all_finite=False, **kwargs):
if _new_check_array_api:
return _check_array(*args, ensure_all_finite=ensure_all_finite,
**kwargs)
else:
return _check_array(*args, force_all_finite=ensure_all_finite,
**kwargs)


__all__ = [
'as_series',
'c',
Expand Down Expand Up @@ -176,10 +196,10 @@ def check_endog(
y : np.ndarray or pd.Series, shape=(n_samples,)
A 1d numpy ndarray
"""
endog = skval.check_array(
endog = check_array(
y,
ensure_2d=False,
force_all_finite=force_all_finite,
ensure_all_finite=force_all_finite,
copy=copy,
dtype=dtype,
)
Expand Down Expand Up @@ -215,7 +235,7 @@ def check_exog(X, dtype=DTYPE, copy=True, force_all_finite=True):
Whether a forced copy will be triggered. If copy=False, a copy might
still be triggered by a conversion.

force_all_finite : bool, optional (default=True)
force_all_finite: bool, optional (default=True)
Whether to raise an error on np.inf and np.nan in an array. The
possibilities are:

Expand All @@ -239,12 +259,12 @@ def check_exog(X, dtype=DTYPE, copy=True, force_all_finite=True):
return X

# otherwise just a pass-through to the scikit-learn method
return skval.check_array(
return check_array(
X,
ensure_2d=True,
dtype=DTYPE,
copy=copy,
force_all_finite=force_all_finite,
ensure_all_finite=force_all_finite,
)


Expand Down Expand Up @@ -330,7 +350,7 @@ def diff(x, lag=1, differences=1):
if any(v < 1 for v in (lag, differences)):
raise ValueError('lag and differences must be positive (> 0) integers')

x = skval.check_array(x, ensure_2d=False, dtype=DTYPE, copy=False)
x = check_array(x, ensure_2d=False, dtype=DTYPE, copy=False)
fun = _diff_vector if x.ndim == 1 else _diff_matrix
res = x

Expand Down Expand Up @@ -385,11 +405,11 @@ def _diff_inv_matrix(x, lag, differences, xi):
if xi is None:
xi = np.zeros((lag * differences, m), dtype=DTYPE)
else:
xi = skval.check_array(
xi = check_array(
xi,
dtype=DTYPE,
copy=False,
force_all_finite=False,
ensure_all_finite=False,
ensure_2d=True,
)
if xi.shape != (lag * differences, m):
Expand Down Expand Up @@ -471,11 +491,11 @@ def diff_inv(x, lag=1, differences=1, xi=None):
----------
.. [1] https://stat.ethz.ch/R-manual/R-devel/library/stats/html/diffinv.html
""" # noqa: E501
x = skval.check_array(
x = check_array(
x,
dtype=DTYPE,
copy=False,
force_all_finite=False,
ensure_all_finite=False,
ensure_2d=False,
)

Expand Down