Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce blackjax sampling memory usage #7407

Merged
merged 6 commits into from
Jul 13, 2024
Merged

Reduce blackjax sampling memory usage #7407

merged 6 commits into from
Jul 13, 2024

Conversation

junpenglao
Copy link
Member

@junpenglao junpenglao commented Jul 9, 2024

Description

Reduce blackjax sampling memory usage by not outputting the warm up diagnostics

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7407.org.readthedocs.build/en/7407/

... by not outputing the warmup diagnositics
@junpenglao
Copy link
Member Author

Will need to upgrade jaxlib requirement first (conda-forge/jaxlib-feedstock#272)

Copy link

codecov bot commented Jul 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.18%. Comparing base (641a60b) to head (e81d828).
Report is 97 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #7407   +/-   ##
=======================================
  Coverage   92.18%   92.18%           
=======================================
  Files         103      103           
  Lines       17258    17259    +1     
=======================================
+ Hits        15909    15910    +1     
  Misses       1349     1349           
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.03% <100.00%> (+0.02%) ⬆️

pyproject.toml Outdated Show resolved Hide resolved
@junpenglao junpenglao merged commit c8b22df into main Jul 13, 2024
22 checks passed
@junpenglao junpenglao deleted the blackjax_memory branch July 13, 2024 07:33
mkusnetsov pushed a commit to mkusnetsov/pymc that referenced this pull request Oct 26, 2024
* Reduce blackjax sampling memory usage 

... by not outputing the warmup diagnositics

* Update jax env

* fix pre-commit

* skip also RuntimeWarning

* ping jax versions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants