Skip to content

Commit 855c5d9

Browse files
revert duplicate removal
1 parent 235bc0e commit 855c5d9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/cellflow/solvers/_genot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def _predict_jit(
285285
x: ArrayLike,
286286
condition: dict[str, ArrayLike] | None = None,
287287
rng: ArrayLike | None = None,
288+
rng_genot: ArrayLike | None = None,
288289
**kwargs: Any,
289290
) -> ArrayLike | tuple[ArrayLike, diffrax.Solution]:
290291
kwargs.setdefault("dt0", None)
@@ -295,7 +296,8 @@ def _predict_jit(
295296
use_mean = rng is None or self.condition_encoder_mode == "deterministic"
296297
rng = utils.default_prng_key(rng)
297298
encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim)
298-
latent = self.latent_noise_fn(rng, (x.shape[0],))
299+
rng_genot = utils.default_prng_key(rng_genot)
300+
latent = self.latent_noise_fn(rng_genot, (x.shape[0],))
299301

300302
def vf(t: float, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
301303
params = self.vf_state.params

0 commit comments

Comments
 (0)