Skip to content

Commit db6dc69

Browse files
committed
needs testing
1 parent a3c43d3 commit db6dc69

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/cellflow/solvers/_otfm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def predict(
202202
x_pred = self._predict_jit(x, condition, rng, **kwargs)
203203
return np.array(x_pred)
204204

205-
@jax.jit
205+
206206
def _predict_jit(
207207
self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any
208208
) -> ArrayLike:
@@ -236,13 +236,15 @@ def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise:
236236
x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))(x, condition, encoder_noise)
237237
return x_pred
238238

239+
239240
def predict_batch(self, x_dict: dict[str, ArrayLike], condition_dict: dict[str, dict[str, ArrayLike]], rng: jax.Array | None = None, **kwargs: Any) -> dict[str, ArrayLike]:
240-
batched_predict = jax.vmap(
241-
self._predict_jit,
242-
in_axes=(0, dict.fromkeys(self.condition_keys, 0))
243-
)
244241
keys = sorted(x_dict.keys())
245242
condition_keys = sorted(set().union(*(condition_dict[k].keys() for k in keys)))
243+
_predict_jit = jax.jit(lambda x, condition: self._predict_jit(x, condition, rng, **kwargs))
244+
batched_predict = jax.vmap(
245+
_predict_jit,
246+
in_axes=(0, dict.fromkeys(condition_keys, 0))
247+
)
246248
src_inputs = jnp.stack([x_dict[k] for k in keys], axis=0)
247249
batched_conditions = {}
248250
for cond_key in condition_keys:

0 commit comments

Comments
 (0)