A numerical relativity code written only in JAX.
JAX_NR implements the BSSN (Baumgarte-Shapiro-Shibata-Nakamura) formulation of Einstein's equations for evolving binary black hole spacetimes. The code leverages JAX's jit compilation and einsum operations.
The code is intended as a toy data generator for binary black hole merger simulations. It does not use adaptive mesh refinement (AMR), spectral methods, Cauchy-characteristic extraction (CCE), or other advanced techniques needed for high-precision gravitational-wave extraction. Moreover, the simulation runs in float32 precision on a fixed uniform grid covering the whole domain, and it uses an implicit backward Euler time integrator instead of the fourth-order Runge-Kutta (RK4) scheme that is common in NR.
Currently, it does not support running on multiple GPUs or multiple CPU cores.
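To make the time-integration choice concrete, here is a minimal sketch of one backward Euler step, $u^{n+1} = u^n + \Delta t\, f(u^{n+1})$, in JAX. It is not JAX_NR's solver. The README does not say how the implicit equation is solved, so this sketch assumes plain fixed-point (Picard) iteration with a fixed iteration count. `backward_euler_step`, `rhs`, and `n_iter` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def backward_euler_step(rhs, u_n, dt, n_iter=10):
    """One implicit (backward) Euler step: solve u = u_n + dt * rhs(u).

    Hypothetical sketch: the implicit equation is solved by fixed-point
    (Picard) iteration, starting from the explicit Euler predictor.
    """
    u0 = u_n + dt * rhs(u_n)  # explicit Euler predictor

    def body(_, u):
        return u_n + dt * rhs(u)  # one fixed-point update

    return jax.lax.fori_loop(0, n_iter, body, u0)


# Linear decay du/dt = -k u: backward Euler gives u_n / (1 + k dt) exactly.
k = 2.0
step = jax.jit(lambda u: backward_euler_step(lambda v: -k * v, u, dt=0.1, n_iter=50))
u1 = step(jnp.array(1.0))  # close to 1 / 1.2
```

Fixed-point iteration only converges when $\Delta t$ times the Lipschitz constant of the right-hand side is below 1. This sketch therefore does not inherit backward Euler's unconditional stability on stiff problems; a Newton-type solve would be needed for that.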
- BSSN Evolution: Full implementation of the BSSN formulation with two moving punctures.
- Binary Black Hole Initial Data: Bowen-York puncture initial data for non-rotating black holes.
- Gravitational Wave Extraction: Weyl scalar $\Psi_4$ computation with spin-weighted spherical harmonic decomposition.
- Up to Date: Damping parameters and stability constraints taken from recent NR papers.
- Constraint Monitoring: Hamiltonian and momentum constraint violation tracking.
- JAX Acceleration: Full JAX compatibility, with `jax.jit` and `jnp.einsum` for performance.
- Fast GPU Solving Time: 30-40 min on an NVIDIA H200 GPU and a few hours on other GPUs; up to a day or more on CPU.
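The `jax.jit` + `jnp.einsum` style can be illustrated by raising an index and taking a trace on a whole grid at once. The sketch assumes a layout with tensor indices first and grid indices last, e.g. shape `(3, 3, n, n, n)`. That layout and the names `trace_and_raise` and `inverse_metric` are assumptions for illustration, not JAX_NR's API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def trace_and_raise(gamma_inv, K):
    """gamma_inv: inverse 3-metric gamma^{ij}, shape (3, 3, n, n, n).
    K: covariant symmetric tensor K_{ij}, same shape.

    Returns the trace gamma^{ij} K_{ij} (shape (n, n, n)) and the mixed
    tensor K^i_j = gamma^{ik} K_{kj} (shape (3, 3, n, n, n)).
    """
    trace = jnp.einsum("ij...,ij...->...", gamma_inv, K)
    K_mixed = jnp.einsum("ik...,kj...->ij...", gamma_inv, K)
    return trace, K_mixed


def inverse_metric(gamma):
    """Pointwise inverse of a (3, 3, n, n, n) metric field."""
    # Move the tensor axes last so jnp.linalg.inv sees a stack of 3x3 matrices.
    g = jnp.moveaxis(gamma, (0, 1), (-2, -1))
    return jnp.moveaxis(jnp.linalg.inv(g), (-2, -1), (0, 1))


n = 8
eye = jnp.eye(3)[:, :, None, None, None] * jnp.ones((n, n, n))
gamma = 2.0 * eye  # gamma_ij = 2 delta_ij everywhere
K = 3.0 * eye      # K_ij = 3 delta_ij everywhere
trace, K_mixed = trace_and_raise(inverse_metric(gamma), K)
# trace = 4.5 everywhere; K_mixed = 1.5 delta^i_j
```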
```bash
git clone https://github.com/AndreiB137/JAX_NR.git
cd JAX_NR
```

CPU-only installation:

```bash
pip install -r requirements.txt
```

GPU installation (CUDA 13):

```bash
pip install -r requirements.txt
pip install -U "jax[cuda13]"
```

GPU installation (CUDA 12):

```bash
pip install -r requirements.txt
pip install -U "jax[cuda12]"
```

Project structure:

```
JAX_NR/
├── main.py                    # File to run the simulation
├── run_config.yml             # Configuration file
├── initial_conditions/
│   ├── __init__.py
│   └── init_cond.py           # Initial conditions Laplacian solver
├── bssn_evolution/
│   ├── __init__.py
│   ├── adm_variables.py       # Converting BSSN variables to ADM
│   ├── boundary_condition.py  # Sommerfeld boundary condition
│   ├── bssn.py                # Evolution equations and solver
│   └── constraint_error.py    # Utils for tracking constraint violations
│
├── wave_extraction/
│   ├── weyl_scalar_extraction.py             # Ψ₄ computation
│   ├── grid_interpolator_3D.py               # Lagrange interpolation
│   └── spin_weighted_spherical_harmonics.py  # SWSH decomposition
├── finite_difference/         # Finite difference operators and curvature-adjusted Kreiss-Oliger dissipation
├── tensor_operations/         # Utils for various tensor operations
└── utils/                     # Small utils for BSSN evolution
```
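The `finite_difference/` package is described as holding FD operators and Kreiss-Oliger (KO) dissipation. The standard pairing for fourth-order schemes is a centered fourth-order first derivative plus KO dissipation built from the sixth difference, and a minimal sketch might look like this. Assumptions: a periodic grid via `jnp.roll` (JAX_NR uses Sommerfeld boundaries instead), no curvature or conformal-factor adjustment, and hypothetical function names. The relation between `sigma` and `ko_damping` is not confirmed.

```python
import jax.numpy as jnp


def _shift(u, s, axis):
    """Return the field u_{i+s} along `axis` (periodic wrap, illustration only)."""
    return jnp.roll(u, -s, axis=axis)


def d1_4th(u, h, axis):
    """Fourth-order centered first derivative along `axis`."""
    return (_shift(u, -2, axis) - 8.0 * _shift(u, -1, axis)
            + 8.0 * _shift(u, 1, axis) - _shift(u, 2, axis)) / (12.0 * h)


def ko_dissipation(u, h, sigma, axis):
    """Kreiss-Oliger dissipation from the sixth difference along `axis`:

    Q u = sigma / (64 h) * (u_{i-3} - 6 u_{i-2} + 15 u_{i-1} - 20 u_i
                            + 15 u_{i+1} - 6 u_{i+2} + u_{i+3}).
    Added to the RHS, it damps the highest-frequency grid modes.
    """
    coeffs = (1.0, -6.0, 15.0, -20.0, 15.0, -6.0, 1.0)
    stencil = sum(c * _shift(u, s, axis) for c, s in zip(coeffs, range(-3, 4)))
    return sigma / (64.0 * h) * stencil


n = 64
h = 2.0 * jnp.pi / n
x = jnp.arange(n) * h
du = d1_4th(jnp.sin(x), h, axis=0)  # approximately cos(x)
checker = jnp.where(jnp.arange(n) % 2 == 0, 1.0, -1.0)  # highest-frequency mode
damp = ko_dissipation(checker, h, sigma=0.05, axis=0)  # equals -(sigma / h) * checker
```

The sixth difference annihilates smooth low-order content (a constant gives exactly zero) while acting on the sawtooth mode at its full strength $-\sigma/h$.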
Edit run_config.yml to set up your binary black hole parameters:
```yaml
grid:
  n: 213                      # Grid points per dimension (n × n × n)
  domain_width: 30.0          # Physical domain size in M
initial_conditions:
  bh_mass1: 0.483             # Mass of first black hole
  bh_mass2: 0.483             # Mass of second black hole
  bh_pos1: [0.0, 3.257, 0.0]  # Position of BH1
  bh_pos2: [0.0, -3.257, 0.0] # Position of BH2
  bh_momentum1: [-0.133, 0.0, 0.0]  # Linear momentum of BH1
  bh_momentum2: [0.133, 0.0, 0.0]   # Linear momentum of BH2
evolution:
  total_time: 250.0           # Total evolution time in M
  dt: 0.0419015               # Time step (should satisfy the CFL condition)
  show_constraints: true      # Monitor constraint violations
  damping_args:
    ko_damping: 0.05          # Strength of the Kreiss-Oliger (KO) dissipation
    ko_strength: 0.5          # How strongly the KO dissipation is adjusted by the conformal factor W
    lapse_damping: 0.8        # Slow-start lapse condition
    gauge_shift_damping: 2.0  # Damping in the shift evolution
    metric_damping: -0.18     # Damping in the conformal metric evolution; helps control the conformal Gamma constraint
    momentum_damping: 0.04    # Damping in the conformal extrinsic curvature evolution, for stability of the momentum constraint
    gamma_damping: 0.9        # Damping in the conformal Gamma evolution
output:
  initial_condition_save_dir: ...  # Where to save the initial BSSN variables produced by the solver
  evolution_save_dir: "..."        # Where to save the BSSN variables; also used as checkpoints if needed
  checkpoint_frequency: 500        # How often to save the BSSN variables
  load_checkpoint: null            # Set only when resuming a previous simulation
wave_extraction:
  enabled: true
  extraction_frequency: 2     # How often to extract the SWSH (-2, 2, 2) coefficient of Ψ₄
  save_dir: ...
  extraction_radius: 14.3     # Radius for wave extraction
  n_theta: 512                # Angular resolution
  n_phi: 512
```

The initial data solver uses the puncture method to solve the Hamiltonian constraint.
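In the puncture approach of Brandt and Brügmann, the conformal factor is split as $\psi = 1 + 1/\alpha + u$ with $1/\alpha = \sum_i m_i/(2 r_i)$ (this $\alpha$ is not the lapse), and the Hamiltonian constraint becomes $\Delta u + \beta\,(1 + \alpha(1+u))^{-7} = 0$ with $\beta = \tfrac{1}{8}\alpha^7 \hat A_{ij}\hat A^{ij}$, where $\hat A_{ij}$ is the Bowen-York extrinsic curvature. The sketch below evaluates only the analytic pieces ($1/\alpha$, $\hat A_{ij}$ for non-spinning holes, and $\beta$) at given points on a flat conformal background. It does not solve for $u$, it is not the code in `init_cond.py`, and reading `bh_mass1`/`bh_mass2` as bare masses $m_i$ is an assumption.

```python
import jax.numpy as jnp


def puncture_sources(x, masses, positions, momenta):
    """Analytic terms of the puncture Hamiltonian constraint (hypothetical helper).

    x: points of shape (..., 3), assumed not to coincide with any puncture.
    masses: (N,) bare masses m_i; positions, momenta: (N, 3).
    Returns psi_BL = 1 + 1/alpha, alpha, beta = alpha^7 A_ij A^ij / 8,
    and the Bowen-York traceless curvature A_ij with shape (..., 3, 3).
    """
    eye = jnp.eye(3)
    inv_alpha = jnp.zeros(x.shape[:-1])
    A = jnp.zeros(x.shape[:-1] + (3, 3))
    for m, c, P in zip(masses, positions, momenta):
        d = x - c
        r = jnp.linalg.norm(d, axis=-1)
        n = d / r[..., None]                     # unit vector from puncture i
        inv_alpha = inv_alpha + m / (2.0 * r)    # Brill-Lindquist term
        Pn = jnp.einsum("...k,k->...", n, P)      # P^k n_k
        bracket = (jnp.einsum("i,...j->...ij", P, n)
                   + jnp.einsum("...i,j->...ij", n, P)
                   - (eye - jnp.einsum("...i,...j->...ij", n, n)) * Pn[..., None, None])
        A = A + 1.5 / r[..., None, None] ** 2 * bracket  # Bowen-York, no spin
    alpha = 1.0 / inv_alpha
    beta = alpha ** 7 * jnp.einsum("...ij,...ij->...", A, A) / 8.0
    return 1.0 + inv_alpha, alpha, beta, A


# Values from run_config.yml, evaluated at two sample points.
pts = jnp.array([[1.0, 0.5, 0.25], [0.0, 0.0, 5.0]])
psi_bl, alpha, beta, A = puncture_sources(
    pts,
    masses=jnp.array([0.483, 0.483]),
    positions=jnp.array([[0.0, 3.257, 0.0], [0.0, -3.257, 0.0]]),
    momenta=jnp.array([[-0.133, 0.0, 0.0], [0.133, 0.0, 0.0]]),
)
```

A full solver would then discretize $\Delta u$ (for example with the FD Laplacian) and iterate on $u$ until the residual of the constraint equation is small.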
The initial BSSN variables produced by the solver are saved in `initial_condition_save_dir`. The simulation is launched with:

```bash
python main.py --config="/path_to/run_config.yml"
```

The (-2, 2, 2) coefficient of $\Psi_4$ on the sphere of radius `extraction_radius` is saved as .npy files in `wave_extraction.save_dir`. Each file contains a single complex number for one step of the simulation.
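The saved coefficient is the projection $C_{22} = \oint \Psi_4\, \overline{{}_{-2}Y_{22}}\, d\Omega$, where, in one common convention, ${}_{-2}Y_{22}(\theta,\phi) = \sqrt{5/(64\pi)}\,(1+\cos\theta)^2 e^{2i\phi}$. The sketch below computes it on an `n_theta` × `n_phi` grid using a midpoint rule in $\theta$ and uniform points in $\phi$. Interpolating $\Psi_4$ from the Cartesian grid onto the sphere is omitted. The quadrature, the sign conventions, and the function names are assumptions and may differ from `spin_weighted_spherical_harmonics.py`.

```python
import jax.numpy as jnp


def swsh_m2_l2_m2(theta, phi):
    """Spin-weight -2, l = 2, m = 2 harmonic (one common convention)."""
    return jnp.sqrt(5.0 / (64.0 * jnp.pi)) * (1.0 + jnp.cos(theta)) ** 2 * jnp.exp(2j * phi)


def sphere_grid(n_theta, n_phi):
    """Midpoint grid in theta, uniform grid in phi, weights sin(theta) dtheta dphi."""
    dtheta = jnp.pi / n_theta
    dphi = 2.0 * jnp.pi / n_phi
    theta = (jnp.arange(n_theta) + 0.5) * dtheta
    phi = jnp.arange(n_phi) * dphi
    TH, PH = jnp.meshgrid(theta, phi, indexing="ij")
    return TH, PH, jnp.sin(TH) * dtheta * dphi


def project_22(psi4, TH, PH, weights):
    """Quadrature approximation of the integral of psi4 * conj(-2Y22) over the sphere."""
    return jnp.sum(psi4 * jnp.conj(swsh_m2_l2_m2(TH, PH)) * weights)


TH, PH, w = sphere_grid(512, 512)
psi4 = (0.3 - 0.1j) * swsh_m2_l2_m2(TH, PH)  # synthetic pure (2, 2) signal
c22 = project_22(psi4, TH, PH, w)             # close to 0.3 - 0.1j
```

Because the harmonics are orthonormal, a signal made only of ${}_{-2}Y_{22}$ returns its amplitude, while modes with a different $e^{im\phi}$ dependence average out over the uniform $\phi$ grid.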
- If running on a GPU, at least 14 GB of free VRAM is required.
- Because the simulation does not use AMR, other initial conditions may fail to converge.
- One key ingredient that makes this simulation work is combining automatic differentiation with the FD stencils. Instead of applying FD to the conformal Gamma constraint, the chain rule is applied until only FD derivatives of the BSSN variables are needed. The ADM variable calculations are handled similarly: where possible, derivatives are propagated by hand until the FD Jacobians or Hessians can be used.
- The code runs in float32, which makes it fast on GPUs.
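The idea of combining automatic differentiation with FD stencils can be sketched at a single grid point. FD supplies $\partial_k \tilde\gamma_{ab}$, and forward-mode AD pushes it through the matrix inverse, giving $\partial_k \tilde\gamma^{ab} = -\tilde\gamma^{ai}\,\partial_k\tilde\gamma_{ij}\,\tilde\gamma^{jb}$ without ever differencing the inverse metric. From it, $\tilde\Gamma^i = -\partial_j \tilde\gamma^{ij}$ when $\det\tilde\gamma = 1$. This is an illustration of the idea, not JAX_NR's code (the README says much of the propagation is done by hand), and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def d_inverse_metric(g, dg):
    """Derivative of the inverse metric at one grid point via the chain rule.

    g:  (3, 3) conformal metric gamma~_ab at the point.
    dg: (3, 3, 3) with dg[k] = d_k gamma~_ab, e.g. taken from FD stencils.
    Returns dginv[k] = d_k gamma~^{ab}: each FD derivative is pushed through
    jnp.linalg.inv with forward-mode AD (jax.jvp).
    """
    def along(dg_k):
        _, tangent = jax.jvp(jnp.linalg.inv, (g,), (dg_k,))
        return tangent
    return jax.vmap(along)(dg)


def conformal_gamma(g, dg):
    """Gamma~^i = -d_j gamma~^{ij}, valid when det(gamma~) = 1."""
    dginv = d_inverse_metric(g, dg)               # indices [j, i, b]
    return -jnp.trace(dginv, axis1=0, axis2=2)   # sum over j = b


# Check against the closed form -g^{-1} (d_k g) g^{-1} for a random metric.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key1, (3, 3))
g = jnp.eye(3) + 0.1 * (M + M.T)
D = jax.random.normal(key2, (3, 3, 3))
dg = 0.5 * (D + jnp.swapaxes(D, 1, 2))  # symmetric in (a, b) for each k
gi = jnp.linalg.inv(g)
closed_form = -jnp.einsum("ia,kab,bj->kij", gi, dg, gi)
dginv = d_inverse_metric(g, dg)
```

On a grid, the same per-point function would be vectorized with `jax.vmap` over all points.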
Figure: lapse contour slice.
Figure: real part of the $\Psi_4$ (-2, 2, 2) coefficient, showing the gravitational-wave signal from the binary black hole merger.
If you use JAX_NR in your work, please cite:

```bibtex
@software{jax_nr,
  author = {Andrei Bodnar and Sandeep S. Cranganore and Johannes Brandstetter},
  title = {JAX\_NR},
  year = {2025},
  url = {https://github.com/AndreiB137/JAX_NR}
}
```

This code builds upon the BSSN formulation and puncture techniques developed by the numerical relativity community.
Large portions of this implementation are based on or inspired by:
- 20k's C++ Numerical Relativity Code: The core BSSN evolution equations, finite difference operators, and constraint calculations were adapted from this implementation and translated to JAX.
The methods implemented here are based on:
- T. W. Baumgarte and S. L. Shapiro, Numerical integration of Einstein's field equations, Phys. Rev. D 59, 024007 (1999)
- M. Campanelli et al., Accurate evolutions of orbiting black-hole binaries without excision, Phys. Rev. Lett. 96, 111101 (2006)
- J. G. Baker et al., Gravitational-wave extraction from an inspiraling configuration of merging black holes, Phys. Rev. Lett. 96, 111102 (2006)
- S. Brandt and B. Brügmann, A simple construction of initial data for multiple black holes, Phys. Rev. Lett. 78, 3606 (1997)
- M. Alcubierre et al., Gauge conditions for long-term numerical black hole evolutions without excision, Phys. Rev. D 67, 084023 (2003) (gauge conditions and stability)
- Z. B. Etienne, Improved moving-puncture techniques for compact binary simulations
- H.-J. Yo et al., Modifications for numerical stability of black hole evolution
- J. R. van Meter et al., How to move a black hole without excision: gauge conditions for the numerical evolution of a moving puncture
- B. Brügmann et al., Calibration of moving puncture simulations


