Official implementation of BSSN Binary Black Hole simulation in JAX.

JAX_NR

JAX License: MIT

A numerical relativity code written only in JAX.

Overview

JAX_NR implements the BSSN (Baumgarte-Shapiro-Shibata-Nakamura) formulation of Einstein's equations for evolving binary black hole spacetimes. The code leverages JAX's jit compilation and einsum operations.

The code is meant as a toy data generator for binary black hole merger simulations, not as a production NR code. It does not use adaptive mesh refinement (AMR), spectral methods, Cauchy-characteristic extraction (CCE), or other advanced techniques needed for very precise gravitational wave extraction. The simulation runs in float32 precision on a fixed uniform grid covering the whole domain, and it uses an implicit backward Euler solver instead of the RK4 integrator that is very common in NR.
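To make the time-stepping choice concrete, here is a minimal sketch of one backward Euler step, `u_next = u_n + dt * rhs(u_next)`. The implicit equation is solved by plain fixed-point iteration, which is an assumption made for illustration: the README does not say how the repository solves it. The names `make_backward_euler_step`, `rhs` and `n_iter` are hypothetical.

```python
import jax
import jax.numpy as jnp

def make_backward_euler_step(rhs, dt, n_iter=20):
    """Build a jitted backward Euler step for du/dt = rhs(u).

    The implicit equation u_next = u_n + dt * rhs(u_next) is solved with
    fixed-point iteration (an assumed solver, chosen for simplicity).
    """
    @jax.jit
    def step(u_n):
        def body(_, u_guess):
            # Re-evaluate the right-hand side at the current guess for u_next
            return u_n + dt * rhs(u_guess)
        return jax.lax.fori_loop(0, n_iter, body, u_n)
    return step

# Toy problem du/dt = -u, whose exact implicit update is u_n / (1 + dt)
dt = 0.1
step = make_backward_euler_step(lambda u: -u, dt)
u0 = jnp.ones(4, dtype=jnp.float32)
u1 = step(u0)  # approaches 1 / (1 + dt) as the iteration converges
```

Fixed-point iteration only converges when `dt` times the stiffness of `rhs` is small enough, so a real solver may need a different strategy.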

At present, it does not support running on multiple GPUs or multiple CPU cores.

Key Features

  • BSSN Evolution: Full implementation of the BSSN formulation with two moving punctures.
  • Binary Black Hole Initial Data: Bowen-York puncture initial data for non-rotating black holes.
  • Gravitational Wave Extraction: Weyl scalar $\Psi_4$ computation with spin-weighted spherical harmonic decomposition.
  • Up to date: Using damping parameters and stability constraints from recent NR papers.
  • Constraint Monitoring: Hamiltonian and momentum constraint violation tracking.
  • JAX Acceleration: Full JAX compatibility with jax.jit and jnp.einsum for maximum performance.
  • Fast GPU Solving Time: 30-40 min on an NVIDIA H200 GPU, a few hours on other GPUs, and up to a day or more on CPU.
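As an illustration of the jax.jit and jnp.einsum pattern mentioned in the feature list, the sketch below contracts tensor fields stored on a 3D grid. The function names and array layout (grid axes first, tensor indices last) are assumptions for this example, not the repository's actual API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def raise_index(inv_metric, covector):
    """V^i = g^{ij} V_j at every grid point.

    inv_metric: (nx, ny, nz, 3, 3), covector: (nx, ny, nz, 3)
    """
    return jnp.einsum("...ij,...j->...i", inv_metric, covector)

@jax.jit
def trace_with_metric(inv_metric, tensor):
    """g^{ij} T_{ij} at every grid point; result has shape (nx, ny, nz)."""
    return jnp.einsum("...ij,...ij->...", inv_metric, tensor)

# Flat-space check on a small grid: the inverse metric is the identity
grid_shape = (4, 4, 4)
inv_metric = jnp.broadcast_to(jnp.eye(3, dtype=jnp.float32), grid_shape + (3, 3))
covector = jnp.ones(grid_shape + (3,), dtype=jnp.float32)

vector = raise_index(inv_metric, covector)        # unchanged in flat space
trace = trace_with_metric(inv_metric, inv_metric)  # equals 3 everywhere
```

Using an ellipsis in the einsum string keeps the same function usable for any grid shape, and jit compiles the contraction once per input shape.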

Installation

1. Clone the repository

git clone https://github.com/AndreiB137/JAX_NR.git
cd JAX_NR

2. Install dependencies

CPU-only installation:

pip install -r requirements.txt

GPU installation (CUDA 13):

pip install -r requirements.txt
pip install -U "jax[cuda13]"

GPU installation (CUDA 12):

pip install -r requirements.txt
pip install -U "jax[cuda12]"

Project Structure

JAX_NR/
├── main.py                          # File to run the simulation
├── run_config.yml                   # Configuration file
├── initial_conditions/
│   ├── __init__.py
│   └── init_cond.py                 # Initial conditions Laplacian solver
├── bssn_evolution/
│   ├── __init__.py                  
│   ├── adm_variables.py             # Converting BSSN variables to ADM
│   ├── boundary_condition.py        # Sommerfeld boundary condition
│   ├── bssn.py                      # Evolution equations and solver
│   └── constraint_error.py          # Utils for tracking constraint violations
│  
├── wave_extraction/
│   ├── weyl_scalar_extraction.py    # Ψ₄ computation
│   ├── grid_interpolator_3D.py      # Lagrange interpolation
│   └── spin_weighted_spherical_harmonics.py  # SWSH decomposition
├── finite_difference/               # Finite difference operators and curvature adjusted Kreiss-Oliger dissipation
├── tensor_operations/               # Utils for various tensor operations
└── utils/                           # Small utils for BSSN evolution

Quick Start

1. Configure Your Simulation

Edit run_config.yml to set up your binary black hole parameters:

grid:
  n: 213                    # Grid points per dimension (n × n × n)
  domain_width: 30.0        # Physical domain size in M

initial_conditions:
  bh_mass1: 0.483           # Mass of first black hole
  bh_mass2: 0.483           # Mass of second black hole
  bh_pos1: [0.0, 3.257, 0.0]      # Position of BH1
  bh_pos2: [0.0, -3.257, 0.0]    # Position of BH2
  bh_momentum1: [-0.133, 0.0, 0.0]  # Linear momentum of BH1
  bh_momentum2: [0.133, 0.0, 0.0]   # Linear momentum of BH2

evolution:
  total_time: 250.0         # Total evolution time in M
  dt: 0.0419015             # Time step (should satisfy CFL)
  show_constraints: true    # Monitor constraint violations

damping_args:
  ko_damping: 0.05           # Strength of the Kreiss-Oliger (KO) dissipation
  ko_strength: 0.5           # How strongly the KO dissipation is adjusted by the conformal factor W
  lapse_damping: 0.8         # Slow start lapse condition
  gauge_shift_damping: 2.0   # Damping in the shift evolution
  metric_damping: -0.18      # Damping in the conformal metric evolution; also helps with the conformal Gamma constraint
  momentum_damping: 0.04     # Damping in the conformal extrinsic curvature evolution for stability in the momentum constraint
  gamma_damping: 0.9         # Damping in the conformal Gamma evolution

output:
  initial_condition_save_dir: ...  # Where to save the initial BSSN variables coming from the solver
  evolution_save_dir: "..."  # Path where the BSSN variables should be saved and used as checkpoint for later if needed
  checkpoint_frequency: 500  # How often to save the BSSN variables
  load_checkpoint: null      # Set only when resuming a previous simulation

wave_extraction:
  enabled: true
  extraction_frequency: 2   # How often to extract the SWSH (-2, 2, 2) coefficient of Ψ₄
  save_dir: ...
  extraction_radius: 14.3   # Radius for wave extraction
  n_theta: 512              # Angular resolution
  n_phi: 512
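The comment on `dt` says it should satisfy a CFL condition. A quick way to check a configuration is to compare `dt` against the grid spacing. The sketch below assumes the spacing is `domain_width / (n - 1)` and a unit wave speed; both are assumptions, since the README does not state how the grid is laid out. The helper name `cfl_ratio` is hypothetical.

```python
def cfl_ratio(n, domain_width, dt, wave_speed=1.0):
    """Return wave_speed * dt / dx for a uniform grid.

    Assumes n points span the domain including both ends,
    so dx = domain_width / (n - 1).
    """
    dx = domain_width / (n - 1)
    return wave_speed * dt / dx

# Values taken from the example run_config.yml above
ratio = cfl_ratio(n=213, domain_width=30.0, dt=0.0419015)
print(f"dt / dx = {ratio:.4f}")  # about 0.3 for these values
```

If `n` or `domain_width` is changed, `dt` can be rescaled to keep this ratio roughly constant.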

2. Generate Initial Data

The initial data solver uses the puncture method to solve the Hamiltonian constraint. The command below writes the resulting initial BSSN variables to initial_condition_save_dir.

python main.py --config="/path_to/run_config.yml"
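In the standard puncture approach, the conformal factor is split into a singular Brill-Lindquist part, 1 + Σ mᵢ / (2 |x − xᵢ|), plus a regular correction that the elliptic solver computes. The sketch below evaluates only the singular part, using the masses and positions from the example configuration. The function name is hypothetical, and the repository's solver may organize this split differently.

```python
import jax.numpy as jnp

def brill_lindquist_psi(points, masses, positions):
    """Singular part of the conformal factor: 1 + sum_i m_i / (2 |x - x_i|).

    points: (P, 3) evaluation points, masses: (N,), positions: (N, 3)
    """
    diff = points[:, None, :] - positions[None, :, :]  # (P, N, 3)
    r = jnp.linalg.norm(diff, axis=-1)                 # (P, N)
    return 1.0 + jnp.sum(masses / (2.0 * r), axis=-1)

# Parameters from the example run_config.yml
masses = jnp.array([0.483, 0.483])
positions = jnp.array([[0.0, 3.257, 0.0], [0.0, -3.257, 0.0]])

# Evaluate away from the punctures, where the expression is finite
points = jnp.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
psi = brill_lindquist_psi(points, masses, positions)
```

The expression diverges at the punctures themselves, so grid points should not coincide with `bh_pos1` or `bh_pos2` when it is evaluated on the full grid.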

3. Run Evolution

python main.py --config="/path_to/run_config.yml"

4. Gravitational Waves

The (-2, 2, 2) spin-weighted spherical harmonic coefficient of Ψ₄, evaluated on the sphere of radius extraction_radius, is saved as .npy files in wave_extraction.save_dir. Each file contains a single complex number for one extraction step of the simulation.
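For readers unfamiliar with this coefficient, the sketch below shows one way to project a field on the sphere onto the s=-2, l=2, m=2 harmonic, using its closed form sqrt(5/(64π)) (1 + cos θ)² e^{2iφ} and a simple midpoint rule in θ. The quadrature, grid layout and function names are assumptions for illustration; the repository's spin_weighted_spherical_harmonics.py may use different conventions.

```python
import jax.numpy as jnp

def swsh_m2_22(theta, phi):
    """Spin-weighted spherical harmonic with s=-2, l=2, m=2."""
    norm = jnp.sqrt(5.0 / (64.0 * jnp.pi))
    return norm * (1.0 + jnp.cos(theta)) ** 2 * jnp.exp(2j * phi)

def project_22(field, theta, phi):
    """Approximate the integral of field * conj(Y) * sin(theta) over the sphere.

    field has shape (n_theta, n_phi); theta and phi are uniform 1D grids.
    """
    dtheta = theta[1] - theta[0]
    dphi = phi[1] - phi[0]
    th, ph = jnp.meshgrid(theta, phi, indexing="ij")
    integrand = field * jnp.conj(swsh_m2_22(th, ph)) * jnp.sin(th)
    return jnp.sum(integrand) * dtheta * dphi

# Midpoint grid in theta, periodic grid in phi
n_theta, n_phi = 64, 128
theta = (jnp.arange(n_theta) + 0.5) * jnp.pi / n_theta
phi = jnp.arange(n_phi) * 2.0 * jnp.pi / n_phi
th, ph = jnp.meshgrid(theta, phi, indexing="ij")

# A field that is exactly 3 times the harmonic should project to about 3
coeff = project_22(3.0 * swsh_m2_22(th, ph), theta, phi)
```

In the actual pipeline the field on the sphere would come from interpolating Ψ₄ from the Cartesian grid onto the extraction sphere.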

5. Additional details

  • Running on a GPU requires at least 14 GB of available VRAM.
  • Because the simulation does not use AMR, other initial conditions may fail to converge.
  • One of the key aspects that makes this simulation work is the use of automatic differentiation combined with the FD stencils. Instead of applying FD directly to the conformal Gamma constraint, the chain rule is applied until only FD derivatives of the BSSN variables are needed. Similarly, for the ADM variable calculations, derivatives are propagated by hand where possible until the FD Jacobians or Hessians can be used.
  • The simulation runs in float32, which makes it fast on GPUs.
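The idea of combining the chain rule with FD stencils can be illustrated in one dimension. Rather than differencing a derived quantity g(u(x)) directly, the sketch below takes g'(u) from automatic differentiation and multiplies it by the FD derivative of the evolved variable u. This is a toy stand-in under stated assumptions: `pointwise_fn`, the periodic grid and the second-order stencil are all made up for the example and are not taken from the repository.

```python
import jax
import jax.numpy as jnp

def fd_derivative(u, dx):
    """Second-order centered difference on a periodic 1D grid."""
    return (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)

def pointwise_fn(u):
    # Hypothetical algebraic expression of an evolved variable
    return jnp.exp(-u) * u ** 2

def derivative_via_chain_rule(u, dx):
    """d/dx g(u(x)) = g'(u) * du/dx.

    g'(u) comes from autodiff at each grid point, du/dx from the FD stencil.
    """
    dg_du = jax.vmap(jax.grad(pointwise_fn))(u)
    return dg_du * fd_derivative(u, dx)

n = 256
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx
u = 1.0 + 0.5 * jnp.sin(x)

approx = derivative_via_chain_rule(u, dx)
exact = jnp.exp(-u) * (2.0 * u - u ** 2) * 0.5 * jnp.cos(x)
max_error = jnp.max(jnp.abs(approx - exact))
```

The only discretization error here comes from differencing u itself; the dependence of g on u is differentiated exactly.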

Misc

Lapse

Contour slices of the lapse at $z=0$ before the merger, close to the merger, and after the merger.

[Figures: lapse at t=42M, t=146M and t=230M]

Gravitational Waves

[Figure: Psi4 timeseries]

Real part of the $\Psi_4$ (-2,2,2) coefficient, showing the gravitational wave signal from the binary black hole merger
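A time series like the one described above can be assembled from the per-step .npy files. The sketch below assumes each file name ends with an integer step index (for example `psi4_120.npy`); that naming scheme is hypothetical, since the README does not document it, so the pattern may need adjusting. The helper name `load_psi4_series` is also made up for this example.

```python
import re
from pathlib import Path

import numpy as np

def load_psi4_series(save_dir):
    """Collect single-value .npy files into arrays sorted by step index.

    Assumes the file stem ends with the integer step (hypothetical naming).
    Returns (steps, values) with values as a complex array.
    """
    entries = []
    for path in Path(save_dir).glob("*.npy"):
        match = re.search(r"(\d+)$", path.stem)
        if match:
            entries.append((int(match.group(1)), path))
    entries.sort()
    steps = np.array([step for step, _ in entries])
    # .item() extracts the single complex number stored in each file
    values = np.array([np.load(path).item() for _, path in entries])
    return steps, values

# Usage (path left as a placeholder):
# steps, psi4_22 = load_psi4_series("/path_to/wave_extraction_save_dir")
# times = steps * dt          # dt from run_config.yml, if steps count time steps
# real_part = psi4_22.real    # the quantity shown in the figure
```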

Citation

If you use JAX_NR in your work, please cite:

@software{jax_nr,
  author = {Andrei Bodnar and Sandeep S. Cranganore and Johannes Brandstetter},
  title = {JAX_NR},
  year = {2025},
  url = {https://github.com/AndreiB137/JAX_NR}
}

Acknowledgments

This code builds upon the BSSN formulation and puncture techniques developed by the numerical relativity community.

Code Attribution

Large portions of this implementation are based on or inspired by:

  • 20k's C++ Numerical Relativity Code: The core BSSN evolution equations, finite difference operators, and constraint calculations were adapted from this implementation and translated to JAX.

Scientific Background

The methods implemented here are based on:
