Collection of tips and tutorials for running JAX on Perlmutter
-
First Introduction to JAX |
Authors: @EiffL
Covers the basic concepts of JAX with a few examples and common gotchas. Presented in July 2021. -
DESC DC2 Telecon: Practical introduction to JAX |
Authors: @EiffL
Covers a few examples of JAX uses cases: implemeting Limber integration, Fisher forecasts, running parallel MCMCs, fitting galaxy light profiles. Presented in Dec. 2021.
Installing JAX on Perlmutter is easy if you follow these steps:
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.htmlAnd that's it, but note that to run properly, JAX will require that you load the following modules: cudnn/8.2.0 nccl/2.11.4 cudatoolkit.
To make sure the necessary modules are loaded when you run your notebooks on JupyterLab, you will then need to create a custom Jupyter kernel.
- Create a template kernel
python -m ipykernel install --user --name jax --display-name JAXThis will create a template kernel named JAX, which we now need to modify slightly.
Go to the newly created kernel directory:
cd $HOME/.local/share/jupyter/kernels/jax
ls
You should see in this directory a kernel.json which we will edit in the next step.
- Edit kernel with custom startup script
Open the kernel.json file and edit to the following:
{
"argv": [
"{resource_dir}/kernel-helper.sh",
"python",
"-m",
"ipykernel_launcher",
"-f",
"{connection_file}"
],
"display_name": "JAX",
"language": "python",
"metadata": {
"debugger": true
}
}Now, in addition, create a new file in the same directory named kernel-helper.sh with the following content:
#!/bin/bash -l
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
exec "$@"Give execution privileges to this file
chmod u+x kernel-helper.shAnd that should be it. Now when you launch the JAX kernel on Perlmutter you should be able to run your jax code without issue.