@@ -202,7 +202,7 @@ def predict(
202
202
x_pred = self ._predict_jit (x , condition , rng , ** kwargs )
203
203
return np .array (x_pred )
204
204
205
- @ jax . jit
205
+
206
206
def _predict_jit (
207
207
self , x : ArrayLike , condition : dict [str , ArrayLike ], rng : jax .Array | None = None , ** kwargs : Any
208
208
) -> ArrayLike :
@@ -236,13 +236,15 @@ def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise:
236
236
x_pred = jax .jit (jax .vmap (solve_ode , in_axes = [0 , None , None ]))(x , condition , encoder_noise )
237
237
return x_pred
238
238
239
+
239
240
def predict_batch (self , x_dict : dict [str , ArrayLike ], condition_dict : dict [str , dict [str , ArrayLike ]], rng : jax .Array | None = None , ** kwargs : Any ) -> dict [str , ArrayLike ]:
240
- batched_predict = jax .vmap (
241
- self ._predict_jit ,
242
- in_axes = (0 , dict .fromkeys (self .condition_keys , 0 ))
243
- )
244
241
keys = sorted (x_dict .keys ())
245
242
condition_keys = sorted (set ().union (* (condition_dict [k ].keys () for k in keys )))
243
+ _predict_jit = jax .jit (lambda x , condition : self ._predict_jit (x , condition , rng , ** kwargs ))
244
+ batched_predict = jax .vmap (
245
+ _predict_jit ,
246
+ in_axes = (0 , dict .fromkeys (condition_keys , 0 ))
247
+ )
246
248
src_inputs = jnp .stack ([x_dict [k ] for k in keys ], axis = 0 )
247
249
batched_conditions = {}
248
250
for cond_key in condition_keys :
0 commit comments