Skip to content

Commit dc68957

Browse files
dkweiss31andyElkingpatrick-kidger
authored
Save fix for t0==t1 (#494)
* Langevin PR (#453) * Langevin PR * Minor fixes * removed the SORT solver (superseded by QUICSORT) * made LangevinTerm.term a static field * temporary fix for _term_compatible and LangevinTerm * Fixed LangevinTerm YAAAYYYYY * Nits * Added Langevin docs, a Langevin example and backwards in time test * Fixed Patrick's comments * langevin -> underdamped_langevin * round of small fixes * check langevin drift term and diffusion term have same args * added scan_trick in QUICSORT and ShOULD * using RuntimeError for when ULD args have wrong structure * small fixes * tidy-ups * Split SDE tests in half, to try and avoid GitHub runner issues? * Added effects_barrier to fix test issue with JAX 0.4.33+ * small fix of docs in all three and a return type in quicsort * bump doc building pipeline * Compatibility with JAX 0.4.36, which removes ConcreteArray * using a fori_loop to save states in edge case t0==t1 * added case for saving t0 data, which was also not getting updated. Added a test * using while_loop, ran into issues with reverse-mode diff using the fori_loop * bug fix for cases when t0=True * simplified logic for saving, no loop necessary * added vmap test * using a fori_loop to save states in edge case t0==t1 * added case for saving t0 data, which was also not getting updated. Added a test * using while_loop, ran into issues with reverse-mode diff using the fori_loop * bug fix for cases when t0=True * simplified logic for saving, no loop necessary * added vmap test * fix t1 out of bounds issue * fix for steps: don't want to update those values if t0==t1 since we didn't take any steps. Added test --------- Co-authored-by: Andraž Jelinčič <66168650+andyElking@users.noreply.github.com> Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Co-authored-by: andyElking <andraz.jelincic@gmail.com>
1 parent ba09fba commit dc68957

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

diffrax/_integrate.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,56 @@ def _save_t1(subsaveat, save_state):
775775
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
776776
return save_state
777777

778+
def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
779+
if subsaveat.ts is not None:
780+
out_size = 1 if subsaveat.t0 else 0
781+
out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0
782+
out_size += len(subsaveat.ts)
783+
ys = jtu.tree_map(
784+
lambda y: jnp.stack([y] * out_size),
785+
subsaveat.fn(t0, yfinal, args),
786+
)
787+
ts = jnp.full(out_size, t0)
788+
if subsaveat.steps:
789+
ysteps = jtu.tree_map(
790+
lambda y: jnp.stack([y] * max_steps),
791+
subsaveat.fn(t0, jnp.full_like(yfinal, jnp.inf), args),
792+
)
793+
ys = jtu.tree_map(
794+
lambda _ys, _ysteps: jnp.concatenate([_ys, _ysteps], axis=0),
795+
ys,
796+
ysteps,
797+
)
798+
ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf)))
799+
save_state = SaveState(
800+
saveat_ts_index=out_size,
801+
ts=ts,
802+
ys=ys,
803+
save_index=out_size,
804+
)
805+
return save_state
806+
778807
save_state = jtu.tree_map(
779808
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
780809
)
810+
811+
# if t0 == t1 then we don't enter the integration loop. In this case we have to
812+
# manually update the saved ts and ys if we want to save at "intermediate"
813+
# times specified by saveat.subs.ts
814+
save_state = jax.lax.cond(
815+
eqxi.unvmap_any(t0 == t1),
816+
lambda __save_state: jax.lax.cond(
817+
t0 == t1,
818+
lambda _save_state: jtu.tree_map(
819+
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
820+
),
821+
lambda _save_state: _save_state,
822+
__save_state,
823+
),
824+
lambda __save_state: __save_state,
825+
save_state,
826+
)
827+
781828
final_state = eqx.tree_at(
782829
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
783830
)

test/test_saveat_solution.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,91 @@ def test_saveat_solution():
147147
assert sol.result == diffrax.RESULTS.successful
148148

149149

150+
@pytest.mark.parametrize("subs", [True, False])
151+
def test_t0_eq_t1(subs):
152+
y0 = jnp.array([2.0])
153+
ts = jnp.linspace(1.0, 1.0, 3)
154+
max_steps = 256
155+
if subs:
156+
get0 = diffrax.SubSaveAt(
157+
ts=ts,
158+
t1=True,
159+
)
160+
get1 = diffrax.SubSaveAt(
161+
t0=True,
162+
ts=ts,
163+
)
164+
get2 = diffrax.SubSaveAt(
165+
t0=True,
166+
ts=ts,
167+
steps=True,
168+
)
169+
subs = (get0, get1, get2)
170+
saveat = diffrax.SaveAt(subs=subs)
171+
else:
172+
saveat = diffrax.SaveAt(t0=True, t1=True, ts=ts)
173+
term = diffrax.ODETerm(lambda t, y, args: y)
174+
sol = diffrax.diffeqsolve(
175+
term,
176+
t0=ts[0],
177+
t1=ts[-1],
178+
y0=y0,
179+
dt0=0.1,
180+
solver=diffrax.Dopri5(),
181+
saveat=saveat,
182+
max_steps=max_steps,
183+
)
184+
if subs:
185+
compare = jnp.full((len(ts) + 1, *y0.shape), y0)
186+
compare_2 = jnp.concatenate(
187+
(compare, jnp.full((max_steps, *y0.shape), jnp.inf))
188+
)
189+
assert tree_allclose(sol.ys[0], compare) # pyright: ignore
190+
assert tree_allclose(sol.ys[1], compare) # pyright: ignore
191+
assert tree_allclose(sol.ys[2], compare_2) # pyright: ignore
192+
else:
193+
compare = jnp.full((len(ts) + 2, *y0.shape), y0)
194+
assert tree_allclose(sol.ys, compare)
195+
196+
197+
@pytest.mark.parametrize("subs", [True, False])
198+
def test_vmap_t0_eq_t1(subs):
199+
ntsave = 4
200+
y0 = jnp.array([2.0])
201+
term = diffrax.ODETerm(lambda t, y, args: y)
202+
203+
def _solve(tf):
204+
ts = jnp.linspace(0.0, tf, ntsave)
205+
get0 = diffrax.SubSaveAt(
206+
ts=ts,
207+
t1=True,
208+
)
209+
get1 = diffrax.SubSaveAt(
210+
t0=True,
211+
ts=ts,
212+
)
213+
subs = (get0, get1)
214+
saveat = diffrax.SaveAt(subs=subs)
215+
return diffrax.diffeqsolve(
216+
term,
217+
t0=ts[0],
218+
t1=ts[-1],
219+
y0=y0,
220+
dt0=0.1,
221+
solver=diffrax.Dopri5(),
222+
saveat=saveat,
223+
)
224+
225+
compare = jnp.full((ntsave + 1, *y0.shape), y0)
226+
sol = jax.vmap(_solve)(jnp.array([0.0, 1.0]))
227+
assert tree_allclose(sol.ys[0][0], compare) # pyright: ignore
228+
assert tree_allclose(sol.ys[1][0], compare) # pyright: ignore
229+
230+
regular_solve = _solve(1.0)
231+
assert tree_allclose(sol.ys[0][1], regular_solve.ys[0]) # pyright: ignore
232+
assert tree_allclose(sol.ys[1][1], regular_solve.ys[1]) # pyright: ignore
233+
234+
150235
def test_trivial_dense():
151236
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
152237
y0 = jnp.array([2.1])

0 commit comments

Comments
 (0)