Skip to content

Commit b9bc6c2

Browse files
committed
pass rng from cellflow.predict
1 parent 7a69c5d commit b9bc6c2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/cellflow/model/_cellflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def predict(
619619
src = batch["source"]
620620
condition = batch.get("condition", None)
621621
out = jax.tree.map(
622-
functools.partial(self.solver.predict, **kwargs),
622+
functools.partial(self.solver.predict, rng=rng, **kwargs),
623623
src,
624624
condition, # type: ignore[attr-defined]
625625
)

0 commit comments

Comments
 (0)