Skip to content

ENH: add quantile #341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
nunique
one_hot
pad
quantile
setdiff1d
sinc
```
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, one_hot, pad
from ._delegation import isclose, one_hot, pad, quantile
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -36,6 +36,7 @@
"nunique",
"one_hot",
"pad",
"quantile",
"setdiff1d",
"sinc",
]
99 changes: 98 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal

from ._lib import _funcs
from ._lib._quantile import quantile as _quantile
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
Expand All @@ -18,7 +19,7 @@
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "one_hot", "pad"]
__all__ = ["isclose", "one_hot", "pad", "quantile"]


def isclose(
Expand Down Expand Up @@ -247,3 +248,99 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def quantile(
x: Array,
q: Array | float,
/,
*,
axis: int | None = None,
keepdims: bool | None = None,
method: str = "linear",
xp: ModuleType | None = None,
) -> Array:
"""
Compute the q-th quantile(s) of the data along the specified axis.

Parameters
----------
x : array of real numbers
Data array.
q : array of float
Probability or sequence of probabilities of the quantiles to compute.
Values must be between 0 and 1 (inclusive). Must have length 1 along
`axis` unless ``keepdims=True``.
axis : int or None, default: None
Axis along which the quantiles are computed. ``None`` ravels both `x`
and `q` before performing the calculation.
keepdims : bool or None, default: None
By default, the axis will be reduced away if possible
(i.e. if there is exactly one element of `q` per axis-slice of `x`).
If `keepdims` is set to True, the axes which are reduced are left in the
result as dimensions with size one. With this option, the result will
broadcast correctly against the original array `x`.
If `keepdims` is set to False, the axis will be reduced away if possible,
and an error will be raised otherwise.
method : str, default: 'linear'
The method to use for estimating the quantile. The available options are:
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
'median_unbiased', 'normal_unbiased'.
xp : array_namespace, optional
The standard-compatible namespace for `x` and `q`. Default: infer.

Returns
-------
array
An array with the quantiles of the data.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]])
>>> xpx.quantile(x, 0.5, axis=-1)
Array([7., 2.], dtype=array_api_strict.float64)
>>> xpx.quantile(x, [0.25, 0.75], axis=-1)
Array([[5., 8.],
[1., 3.]], dtype=array_api_strict.float64)
"""
# We only support a subset of the methods supported by scipy.stats.quantile.
# So we need to perform the validation here.
methods = {
"inverted_cdf",
"averaged_inverted_cdf",
"closest_observation",
"hazen",
"interpolated_inverted_cdf",
"linear",
"median_unbiased",
"normal_unbiased",
"weibull",
}
if method not in methods:
msg = f"`method` must be one of {methods}"
raise ValueError(msg)

xp = array_namespace(x, q) if xp is None else xp

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scikit-learn/scikit-learn#31671 (comment) suggests that delegation to some existing array libraries may be desirable here

if is_dask_namespace(xp):
return xp.quantile(x, q, axis=axis, keepdims=keepdims, method=method)

try:
import scipy # type: ignore[import-untyped]
from packaging import version

# The quantile function in scipy 1.16 supports array API directly, no need
# to delegate
if version.parse(scipy.__version__) >= version.parse("1.17"): # pyright: ignore[reportUnknownArgumentType]
from scipy.stats import ( # type: ignore[import-untyped]
quantile as scipy_quantile,
)

return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)
except (ImportError, AttributeError):
pass

return _quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)
149 changes: 149 additions & 0 deletions src/array_api_extra/_lib/_quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Quantile implementation."""

from types import ModuleType
from typing import cast

from ._at import at
from ._utils import _compat
from ._utils._compat import array_namespace
from ._utils._typing import Array


def quantile(
x: Array,
q: Array | float,
/,
*,
axis: int | None = None,
keepdims: bool | None = None,
method: str = "linear",
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
if xp is None:
xp = array_namespace(x, q)

q_is_scalar = isinstance(q, int | float)
if q_is_scalar:
q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))
q_arr = cast(Array, q)

if not xp.isdtype(x.dtype, ("integral", "real floating")):
msg = "`x` must have real dtype."
raise ValueError(msg)
if not xp.isdtype(q_arr.dtype, "real floating"):
msg = "`q` must have real floating dtype."
raise ValueError(msg)

# Promote to common dtype
x = xp.astype(x, xp.float64)
q_arr = xp.asarray(q_arr, dtype=xp.float64, device=_compat.device(x))

dtype = x.dtype
axis_none = axis is None
ndim = max(x.ndim, q_arr.ndim)

if axis_none:
x = xp.reshape(x, (-1,))
q_arr = xp.reshape(q_arr, (-1,))
axis = 0
elif not isinstance(axis, int): # pyright: ignore[reportUnnecessaryIsInstance]
msg = "`axis` must be an integer or None."
raise ValueError(msg)
elif axis >= ndim or axis < -ndim:
msg = "`axis` is not compatible with the shapes of the inputs."
raise ValueError(msg)
else:
axis = int(axis)

if keepdims not in {None, True, False}:
msg = "If specified, `keepdims` must be True or False."
raise ValueError(msg)

if x.shape[axis] == 0:
shape = list(x.shape)
shape[axis] = 1
x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))

y = xp.sort(x, axis=axis)

# Move axis to the end for easier processing
y = xp.moveaxis(y, axis, -1)
if not (q_is_scalar or q_arr.ndim == 0):
q_arr = xp.moveaxis(q_arr, axis, -1)

n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))

# Validate that q values are in the range [0, 1]
if xp.any((q_arr < 0) | (q_arr > 1)):
msg = "`q` must contain values between 0 and 1 inclusive."
raise ValueError(msg)

res = _quantile_hf(y, q_arr, n, method, xp)

# Reshape per axis/keepdims
if axis_none and keepdims:
shape = (1,) * (ndim - 1) + res.shape
res = xp.reshape(res, shape)
axis = -1

# Move axis back to original position
res = xp.moveaxis(res, -1, axis)

if not keepdims and res.shape[axis] == 1:
res = xp.squeeze(res, axis=axis)

if res.ndim == 0:
return res[()]
return res


def _quantile_hf(
y: Array, p: Array, n: Array, method: str, xp: ModuleType
) -> Array: # numpydoc ignore=PR01,RT01
"""Helper function for Hyndman-Fan quantile method."""
ms: dict[str, Array | int | float] = {
"inverted_cdf": 0,
"averaged_inverted_cdf": 0,
"closest_observation": -0.5,
"interpolated_inverted_cdf": 0,
"hazen": 0.5,
"weibull": p,
"linear": 1 - p,
"median_unbiased": p / 3 + 1 / 3,
"normal_unbiased": p / 4 + 3 / 8,
}
m = ms[method]

jg = p * n + m - 1
# Convert both to integers, the type of j and n must be the same
# for us to be able to `xp.clip` them.
j = xp.astype(jg // 1, xp.int64)
n = xp.astype(n, xp.int64)
g = jg % 1

if method == "inverted_cdf":
g = xp.astype((g > 0), jg.dtype)
elif method == "averaged_inverted_cdf":
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
elif method == "closest_observation":
g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)
if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}:
g = xp.asarray(g)
g = at(g, jg < 0).set(0)
g = at(g, j < 0).set(0)
j = xp.clip(j, 0, n - 1)
jp1 = xp.clip(j + 1, 0, n - 1)

# Broadcast indices to match y shape except for the last axis
if y.ndim > 1:
# Create broadcast shape for indices
broadcast_shape = [*y.shape[:-1], 1]
j = xp.broadcast_to(j, broadcast_shape)
jp1 = xp.broadcast_to(jp1, broadcast_shape)
g = xp.broadcast_to(g, broadcast_shape)

res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
y, jp1, axis=-1
)
return res # noqa: RET504
72 changes: 72 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
nunique,
one_hot,
pad,
quantile,
setdiff1d,
sinc,
)
Expand All @@ -43,6 +44,7 @@
lazy_xp_function(nunique)
lazy_xp_function(one_hot)
lazy_xp_function(pad)
lazy_xp_function(quantile)
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False)
lazy_xp_function(sinc)
Expand Down Expand Up @@ -1162,3 +1164,73 @@ def test_device(self, xp: ModuleType, device: Device):

def test_xp(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestQuantile:
def test_basic(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
actual = quantile(x, 0.5)
expect = xp.asarray(3.0)
xp_assert_close(actual, expect)

def test_multiple_quantiles(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
actual = quantile(x, xp.asarray([0.25, 0.5, 0.75]))
expect = xp.asarray([2.0, 3.0, 4.0])
xp_assert_close(actual, expect)

def test_2d_axis(self, xp: ModuleType):
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
actual = quantile(x, 0.5, axis=0)
expect = xp.asarray([2.5, 3.5, 4.5])
xp_assert_close(actual, expect)

def test_2d_axis_keepdims(self, xp: ModuleType):
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
actual = quantile(x, 0.5, axis=0, keepdims=True)
expect = xp.asarray([[2.5, 3.5, 4.5]])
xp_assert_close(actual, expect)

def test_methods(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
methods = ["linear", "hazen", "weibull"]
for method in methods:
actual = quantile(x, 0.5, method=method)
# All methods should give reasonable results
assert 2.5 <= float(actual) <= 3.5

def test_edge_cases(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
# q = 0 should give minimum
actual = quantile(x, 0.0)
expect = xp.asarray(1.0)
xp_assert_close(actual, expect)

# q = 1 should give maximum
actual = quantile(x, 1.0)
expect = xp.asarray(5.0)
xp_assert_close(actual, expect)

def test_invalid_q(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
# q > 1 should raise
with pytest.raises(
ValueError, match="`q` must contain values between 0 and 1 inclusive"
):
_ = quantile(x, 1.5)

with pytest.raises(
ValueError, match="`q` must contain values between 0 and 1 inclusive"
):
_ = quantile(x, -0.5)

def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3, 4, 5], device=device)
actual = quantile(x, 0.5)
assert get_device(actual) == device

def test_xp(self, xp: ModuleType):
x = xp.asarray([1, 2, 3, 4, 5])
actual = quantile(x, 0.5, xp=xp)
expect = xp.asarray(3.0)
xp_assert_close(actual, expect)
Loading