Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce blackjax sampling memory usage #7407

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
ping jax versions
  • Loading branch information
junpenglao committed Jul 13, 2024
commit e81d8289ddbaa093874d8a6c8d0bdda6dfdc206e
9 changes: 5 additions & 4 deletions conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ dependencies:
- cloudpickle
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax
- jaxlib
- jax
- blackjax>=1.2.2
- jax>=0.4.28
- jaxlib>=0.4.28
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
Expand All @@ -25,7 +25,8 @@ dependencies:
- networkx
- rich>=13.7.1
- threadpoolctl>=3.1.0
- scipy
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26
- scipy>=1.13.0
- typing-extensions>=3.7.4
# Extra dependencies for testing
- ipython>=7.16
Expand Down
Loading