Skip to content

Commit 03dfb2b

Browse files
committed
init
1 parent 70eec9a commit 03dfb2b

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/cellflow/solvers/_otfm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ def predict(
199199
-------
200200
The push-forward distribution of ``x`` under condition ``condition``.
201201
"""
202+
x_pred = self._predict_jit(x, condition, rng, **kwargs)
203+
return np.array(x_pred)
204+
205+
def _predict_jit(
206+
self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any
207+
) -> ArrayLike:
208+
"""See :meth:`OTFlowMatching.predict`."""
202209
kwargs.setdefault("dt0", None)
203210
kwargs.setdefault("solver", diffrax.Tsit5())
204211
kwargs.setdefault("stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5))
@@ -226,7 +233,10 @@ def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise:
226233
return result.ys[0]
227234

228235
x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))(x, condition, encoder_noise)
229-
return np.array(x_pred)
236+
return x_pred
237+
238+
def predict_batch(self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any) -> ArrayLike:
239+
pass
230240

231241
@property
232242
def is_trained(self) -> bool:

0 commit comments

Comments
 (0)