Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/median performance #20

Merged
merged 9 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
- name: Run tests
run: |
export PYTHONPATH="${PYTHONPATH}:/robustbase/"
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ celerybeat.pid

# Environments
.env
.venv
.venv*
env/
venv/
ENV/
env.bak/
venv.bak/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This package provides functions to calculate the following robust statistical es

```python
from robustbase.stats import Qn

x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# With bias correction
Expand All @@ -37,11 +37,11 @@ res = Qn(x, finite_corr=False) # result: 4.43828

```python
from robustbase.stats import Sn

x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# With bias correction
res = Sn(x) # result: 3.5778
res = Sn(x) # result: 3.5778

# Without bias correction
res = Sn(x, finite_corr=False) # result: 3.5778
Expand Down Expand Up @@ -75,7 +75,7 @@ For local development setup:
```sh
git clone https://github.com/deepak7376/robustbase
cd robustbase
pip install -r requirements.txt
pip install -r requirements.txt -r requirements-dev.txt
```

## Recent Changes
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest>=8.1.1
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ certifi>=2019.11.28
docutils>=0.15.2
numpy>=1.18.0
statistics>=1.0.3.5
pytest>=8.1.1
8 changes: 4 additions & 4 deletions robustbase/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .robustbase import Qn
from .robustbase import Sn
from .robustbase import iqr
from .robustbase import mad
from .robustbase import Qn # noqa: F401
from .robustbase import Sn # noqa: F401
from .robustbase import iqr # noqa: F401
from .robustbase import mad # noqa: F401


8 changes: 4 additions & 4 deletions robustbase/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .Qn import Qn
from .Sn import Sn
from .iqr import iqr
from .mad import mad
from .iqr import iqr # noqa: F401
from .mad import mad # noqa: F401
from .Qn import Qn # noqa: F401
from .Sn import Sn # noqa: F401
4 changes: 2 additions & 2 deletions robustbase/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .mean import mean
from .median import median
from .mean import mean # noqa: F401
from .median import median # noqa: F401
29 changes: 18 additions & 11 deletions robustbase/utils/median.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,29 @@ def median(x, low=False, high=False):

Parameters:
- x: list or array-like, numeric vector of observations.
- low: bool, if True, return the low median for even sample size.
- low: bool, if True, return the low median for even sample size. If ``True``, ``high`` is ignored.
- high: bool, if True, return the high median for even sample size.

Returns:
- float: Median value.
"""
sorted_x = np.sort(x)
n = len(sorted_x)

n = len(x)
if n == 0:
raise ValueError("Empty list provided.")


# for odd sample size, all three medians are the same
if n % 2 == 1:
return sorted_x[n // 2]
elif low:
return sorted_x[n // 2 - 1]
elif high:
return sorted_x[n // 2]
else:
return (sorted_x[n // 2 - 1] + sorted_x[n // 2]) / 2
return np.median(a=x)

# for even sample sizes, the median is the average of the two middle values if
# neither the low nor high median is requested
if not (low or high):
return np.median(a=x)

# otherwise, either the low or the high median is found via introselect
median_idx = n // 2
if low:
median_idx -= 1

return np.partition(a=x, kth=median_idx)[median_idx]
47 changes: 47 additions & 0 deletions tests/test_median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional, Tuple, Union

import numpy as np
import pytest

from robustbase.utils.median import median

# Shared fixtures for the median tests.
# Empty sample: used for the cases that expect a ValueError.
X_EMPTY = []
# Odd-sized sample (n=5); sorted: -10.0, -2.1, 3.2, 5.5, 8.4 -> middle value 3.2.
X_ODD_N = [5.5, 3.2, -10.0, -2.1, 8.4]
# Even-sized sample (n=6); sorted: -10.0, -2.1, 0.0, 3.2, 5.5, 8.4
# -> low median 0.0, high median 3.2, averaged median 1.6.
X_EVEN_N = [5.5, 3.2, -10.0, -2.1, 8.4, 0.0]


@pytest.mark.parametrize("as_array", [False, True])
@pytest.mark.parametrize(
    "comb",
    [
        (X_EMPTY, False, False, None),
        (X_EMPTY, True, False, None),
        (X_EMPTY, False, True, None),
        (X_EMPTY, True, True, None),
        (X_ODD_N, False, False, 3.2),
        (X_ODD_N, True, False, 3.2),
        (X_ODD_N, False, True, 3.2),
        (X_ODD_N, True, True, 3.2),
        (X_EVEN_N, False, False, 1.6),
        (X_EVEN_N, True, False, 0.0),
        (X_EVEN_N, False, True, 3.2),
        (X_EVEN_N, True, True, 0.0),
    ],
)
def test_median(
    comb: Tuple[Union[list, np.ndarray], bool, bool, Optional[float]],
    as_array: bool,
):
    """Check ``median`` for every ``low``/``high`` flag combination.

    Each ``comb`` is ``(sample, low, high, expected)``; an ``expected`` of
    ``None`` marks a case where ``median`` must raise ``ValueError``.
    ``as_array`` runs the same case with the sample converted to a NumPy array.
    """
    sample, want_low, want_high, expected = comb
    data = np.array(sample) if as_array else sample

    # regular case: the computed median has to match the expected value
    if expected is not None:
        assert median(data, low=want_low, high=want_high) == expected
        return

    # error case: no expected value means the call has to raise
    with pytest.raises(ValueError):
        median(data, low=want_low, high=want_high)
Loading