Skip to content

Commit 8bc215c

Browse files
committed
Fix CubicInterpolation for the case where ys = {}. Add unit tests for both cubic and linear interpolation checking this case
1 parent 0b93a3c commit 8bc215c

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

diffrax/global_interpolation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,10 @@ def backward_hermite_coefficients(
750750

751751
ts = _check_ts(ts)
752752
fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts)
753+
754+
if len(jtu.tree_leaves(ys)) == 0:
755+
return ({}, {}, {}, {})
756+
753757
if deriv0 is None:
754758
if replace_nans_at_start is None:
755759
coeffs = jtu.tree_map(fn, ys)

test/test_global_interpolation.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,7 @@ def test_interpolation_classes(mode, getkey):
262262
jnp.array([0.0, 2.0, 3.0, 3.1, 4.0, 4.1, 5.0, 5.1]),
263263
]
264264
_make = lambda: jrandom.normal(getkey(), (length, num_channels))
265-
ys_ = [
266-
_make(),
267-
[_make(), {"a": _make(), "b": _make()}],
268-
]
265+
ys_ = [_make(), [_make(), {"a": _make(), "b": _make()}], {}]
269266
for ts in ts_:
270267
assert len(ts) == length
271268
for ys in ys_:

0 commit comments

Comments
 (0)