Skip to content

Commit ffc311c

Browse files
committed
fix: prune nan idxs in output when omitting nans
1 parent 7e2f14a commit ffc311c

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ module-name = "tsdownsample._rust._tsdownsample_rs" # The path to place the comp
3939

4040
# Linting
4141
[tool.ruff]
42-
select = ["E", "F", "I"]
4342
line-length = 88
44-
extend-select = ["Q"]
45-
ignore = ["E402", "F403"]
43+
lint.select = ["E", "F", "I"]
44+
lint.extend-select = ["Q"]
45+
lint.ignore = ["E402", "F403"]
4646

4747
# Formatting
4848
[tool.black]

tests/test_tsdownsample.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ def generate_rust_downsamplers() -> Iterable[AbstractDownsampler]:
4444
yield downsampler
4545

4646

47-
def generate_rust_nan_downsamplers() -> Iterable[AbstractDownsampler]:
48-
for downsampler in RUST_NAN_DOWNSAMPLERS:
49-
yield downsampler
50-
51-
5247
def generate_all_downsamplers() -> Iterable[AbstractDownsampler]:
5348
for downsampler in RUST_DOWNSAMPLERS + RUST_NAN_DOWNSAMPLERS + OTHER_DOWNSAMPLERS:
5449
yield downsampler
@@ -106,7 +101,7 @@ def test_rust_downsampler(downsampler: AbstractDownsampler):
106101
assert s_downsampled[-1] == len(arr) - 1
107102

108103

109-
@pytest.mark.parametrize("downsampler", generate_rust_nan_downsamplers())
104+
@pytest.mark.parametrize("downsampler", RUST_NAN_DOWNSAMPLERS)
110105
def test_rust_nan_downsampler(downsampler: AbstractRustNaNDownsampler):
111106
"""Test the Rust NaN downsamplers."""
112107
datapoints = generate_nan_datapoints()
@@ -360,3 +355,41 @@ def test_nan_minmaxlttb_downsampler():
360355
s_downsampled = NaNMinMaxLTTBDownsampler().downsample(arr, n_out=100)
361356
arr_downsampled = arr[s_downsampled]
362357
assert np.all(np.isnan(arr_downsampled[1:-1])) # first and last are not NaN
358+
359+
360+
@pytest.mark.parametrize("downsampler", RUST_DOWNSAMPLERS)
def test_no_nans_omitted(downsampler: AbstractDownsampler):
    """Check that no returned index points at a NaN value in ``y``.

    A contiguous run of NaNs is written into an otherwise monotonically
    increasing series, after which the downsampler is invoked with and
    without an explicit ``x`` array, both sequentially and in parallel.
    Every index that comes back must select a non-NaN element of ``y``.
    """
    n = 10_000
    y = np.arange(n, dtype=np.float64)
    # One slice assignment covers indices 101..199 (same span the element-wise
    # loop over ``range(1, 100)`` with offset 100 used to fill).
    y[101:200] = np.nan
    x = np.arange(n)

    # Call order: (y), (y, parallel), (x, y), (x, y, parallel).
    for label, args in (("y only", (y,)), ("x and y", (x, y))):
        for parallel in (False, True):
            s_downsampled = downsampler.downsample(
                *args, n_out=1000, parallel=parallel
            )
            assert np.all(
                ~np.isnan(y[s_downsampled])
            ), f"NaN index returned ({label}, parallel={parallel})"
377+
378+
379+
@pytest.mark.parametrize("downsampler", RUST_NAN_DOWNSAMPLERS)
def tests_nans_returned(downsampler: AbstractRustNaNDownsampler):
    """Check that at least one returned index points at a NaN value in ``y``.

    A contiguous run of NaNs is written into an otherwise monotonically
    increasing series, after which the downsampler is invoked with and
    without an explicit ``x`` array, both sequentially and in parallel.
    Each result must contain at least one index selecting a NaN element.
    """
    n = 10_000
    y = np.arange(n, dtype=np.float64)
    # One slice assignment covers indices 101..199 (same span the element-wise
    # loop over ``range(1, 100)`` with offset 100 used to fill).
    y[101:200] = np.nan
    x = np.arange(n)

    # Call order: (y), (y, parallel), (x, y), (x, y, parallel).
    for label, args in (("y only", (y,)), ("x and y", (x, y))):
        for parallel in (False, True):
            s_downsampled = downsampler.downsample(
                *args, n_out=1000, parallel=parallel
            )
            assert np.any(
                np.isnan(y[s_downsampled])
            ), f"no NaN index returned ({label}, parallel={parallel})"

tsdownsample/downsampling_interface.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,10 @@ def _switch_mod_with_x_and_y(
335335
# TIMEDELTA -> i64 (timedelta64 is viewed as int64)
336336
raise ValueError(f"Unsupported data type (for x): {x_dtype}")
337337

338+
def _prune_nans(self, sampled_idxs: np.ndarray, y: np.ndarray) -> np.ndarray:
339+
"""Remove all nan indices."""
340+
return sampled_idxs[~np.isnan(y[sampled_idxs])]
341+
338342
def _downsample(
339343
self,
340344
x: Union[np.ndarray, None],
@@ -359,11 +363,11 @@ def _downsample(
359363
## Viewing the x-data as different dtype (if necessary)
360364
if x is None:
361365
downsample_f = self._switch_mod_with_y(y.dtype, mod)
362-
return downsample_f(y, n_out, **kwargs)
366+
return self._prune_nans(downsample_f(y, n_out, **kwargs), y)
363367
x = self._view_x(x)
364368
## Getting the appropriate downsample function
365369
downsample_f = self._switch_mod_with_x_and_y(x.dtype, y.dtype, mod)
366-
return downsample_f(x, y, n_out, **kwargs)
370+
return self._prune_nans(downsample_f(x, y, n_out, **kwargs), y)
367371

368372
def downsample(self, *args, n_out: int, parallel: bool = False, **kwargs):
369373
"""Downsample the data in x and y.
@@ -400,6 +404,11 @@ def _downsample_func_prefix(self) -> str:
400404
"""The prefix of the downsample functions in the rust module."""
401405
return NAN_DOWNSAMPLE_F
402406

407+
## Overriding the _prune_nans method to return the sampled indices without pruning
408+
def _prune_nans(self, sampled_idxs: np.ndarray, y: np.ndarray) -> np.ndarray:
409+
"""Remove all nan indices."""
410+
return sampled_idxs
411+
403412
def _switch_mod_with_y(
404413
self, y_dtype: np.dtype, mod: ModuleType, downsample_func: Optional[str] = None
405414
) -> Callable:

0 commit comments

Comments
 (0)