Skip to content

Commit db7e8d4

Browse files
committed
Fix CubicInterpolation for the case where ys = {}. Add unit tests for both cubic and linear interpolation checking this case
1 parent 71c42b1 commit db7e8d4

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

diffrax/_global_interpolation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,10 @@ def backward_hermite_coefficients(
758758

759759
ts = _check_ts(ts)
760760
fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts)
761+
762+
if len(jtu.tree_leaves(ys)) == 0:
763+
return ({}, {}, {}, {})
764+
761765
if deriv0 is None:
762766
if replace_nans_at_start is None:
763767
coeffs = jtu.tree_map(fn, ys)

test/test_global_interpolation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def test_interpolation_classes(mode, getkey):
261261
ys_ = [
262262
_make(),
263263
[_make(), {"a": _make(), "b": _make()}],
264+
{}
264265
]
265266
for ts in ts_:
266267
assert len(ts) == length

0 commit comments

Comments
 (0)