jax_hf provides a JAX implementation of a Hartree–Fock self‑consistent field (SCF) loop on uniform 2D k‑grids, with optional JIT compilation.
- FFT‑based exchange in k‑space
- Dense Hermitian diagonalization (JAX eigh)
- DIIS/EDIIS-style mixing to stabilize and accelerate SCF convergence
- NumPy/JAX‑friendly API, easy to integrate with other JAX code
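To illustrate the idea behind FFT-based exchange, here is a minimal NumPy sketch. It is not jax_hf's internal code. It assumes an exchange term of the form Sigma(k) = -w * sum_{k'} V(k - k') P(k') on a uniform periodic grid with a constant weight w per k-point, an orbital-diagonal density-density interaction (so V multiplies P elementwise in the orbital indices), and V stored in FFT order (q = 0 at index 0). Under these assumptions the k'-sum is a circular convolution, so the convolution theorem reduces it to FFTs costing O(N log N) instead of O(N^2) for N = nk^2 grid points. The function names `exchange_fft` and `exchange_direct` are hypothetical. The same `fft2` calls exist in `jax.numpy.fft`.

```python
import numpy as np


def exchange_fft(P, Vq, w):
    """Sigma(k) = -w * sum_{k'} V(k - k') P(k') via the convolution theorem.

    P:  (nk, nk, d, d) density matrix on the k-grid
    Vq: (nk, nk, 1, 1) or (nk, nk, d, d) interaction, FFT-ordered (q = 0 at index 0)
    w:  scalar integration weight per k-point (assumed uniform)
    """
    # Circular convolution over the two k axes equals a product in Fourier space
    P_hat = np.fft.fft2(P, axes=(0, 1))
    V_hat = np.fft.fft2(Vq, axes=(0, 1))
    return -w * np.fft.ifft2(V_hat * P_hat, axes=(0, 1))


def exchange_direct(P, Vq, w):
    """Reference O(N^2) evaluation of the same sum with periodic (modular) q indices."""
    nk1, nk2 = P.shape[:2]
    out = np.zeros(np.broadcast_shapes(P.shape, Vq.shape), dtype=complex)
    for i in range(nk1):
        for j in range(nk2):
            for a in range(nk1):
                for b in range(nk2):
                    out[i, j] += Vq[(i - a) % nk1, (j - b) % nk2] * P[a, b]
    return -w * out


# Usage: the FFT result agrees with the direct sum on a small random grid
rng = np.random.default_rng(0)
nk, d = 4, 2
P = rng.normal(size=(nk, nk, d, d)) + 1j * rng.normal(size=(nk, nk, d, d))
Vq = rng.normal(size=(nk, nk, 1, 1))
print(np.allclose(exchange_fft(P, Vq, 1 / nk**2), exchange_direct(P, Vq, 1 / nk**2)))  # True
```

If V is tabulated on a centered grid, `np.fft.ifftshift(Vq, axes=(0, 1))` moves q = 0 to index 0. Which ordering jax_hf itself expects for `coulomb_q` is not stated here.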
Project links:
Users (PyPI):
```bash
pip install jax-hf
```

Note: jax_hf depends on JAX. For CPU-only installs, pip usually pulls in a working JAX wheel automatically. For GPU support, follow JAX's official installation guide to select the correct extras/wheels for your CUDA/cuDNN stack.
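After installing, it can help to confirm which JAX backend was picked up and whether 64-bit types are active. By default JAX uses 32-bit types and converts float64/complex128 inputs (such as the complex128 arrays in the quick start) to float32/complex64 unless 64-bit mode is enabled. Whether jax_hf enables this mode on its own is not stated here, so the sketch below only uses public JAX configuration calls and should run before any JAX arrays are created.

```python
import jax

# Enable 64-bit types; JAX defaults to 32-bit (float32/complex64)
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Which backend and devices did the installed jaxlib provide?
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# With x64 enabled, complex128 is kept instead of being downcast to complex64
x = jnp.zeros((2, 2), dtype=jnp.complex128)
print(x.dtype)  # complex128
```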
Developers (editable install):
```bash
git clone https://github.com/skilledwolf/jax_hf.git
cd jax_hf
pip install -e .
```

Quick start:

```python
import numpy as np
import jax_hf

# Grid and shapes: uniform periodic nk x nk k-grid on [-1, 1)^2, d orbitals per k-point
nk = 128
d = 2
dk = 2.0 / nk                                   # grid spacing
K = np.linspace(-1.0, 1.0, nk, endpoint=False)  # exclude the endpoint so the spacing equals dk

# Integration weight of each k-point: dk^2 / (2*pi)^2 (uniform scalar mesh measure)
weights = np.full((nk, nk), dk * dk / (2 * np.pi) ** 2)

# Single-particle Hamiltonian H(k); all zeros here as a placeholder
H = np.zeros((nk, nk, d, d), dtype=np.complex128)

# Screened interaction V(q); trailing (1, 1) axes broadcast it over orbital pairs
Vq = (1.0 / np.sqrt(K[:, None]**2 + K[None, :]**2 + 0.1)).astype(np.complex128)[..., None, None]

# Initial guess for the density matrix P(k)
P0 = np.zeros_like(H)

# Target electron density (half-filling: d/2 electrons per k-point)
ne_target = 0.5 * d * weights.sum()

# Build the HF kernel (stores JAX arrays internally)
kernel = jax_hf.HartreeFockKernel(
    weights,  # (nk, nk)
    H,        # (nk, nk, d, d)
    Vq,       # (nk, nk, 1, 1) or (nk, nk, d, d)
    T=0.5,    # temperature
)

# JIT-compile the SCF iteration function (optional but recommended)
hf_iter = jax_hf.jit_hartreefock_iteration(kernel)

P_fin, F_fin, E_fin, mu_fin, n_iter, history = hf_iter(
    P0, float(ne_target),
    max_iter=50, comm_tol=1e-3, diis_size=6, log_every=None,
)
print("iters:", int(n_iter), "mu:", float(mu_fin), "E:", float(E_fin))
```

API overview:

```python
class HartreeFockKernel:
    def __init__(self, weights, hamiltonian, coulomb_q, T: float):
        ...


def hartreefock_iteration(
    P0, electrondensity0, hf_step: HartreeFockKernel,
    *, max_iter=100, comm_tol=5e-3, diis_size=4, log_every: int | None = 1,
):
    """Runs SCF and returns (P_fin, F_fin, E_fin, mu_fin, n_iter, history)."""


def jit_hartreefock_iteration(hf_step: HartreeFockKernel):
    """Returns a jitted version of hartreefock_iteration with static args."""
```

- shapes: `weights` is (nk, nk), `hamiltonian` is (nk, nk, d, d), `coulomb_q` is (nk, nk, 1, 1) or (nk, nk, d, d), and `P0` matches `hamiltonian`, i.e. (nk, nk, d, d).
- inputs: `electrondensity0` is the target electron density (`ne_target` in the quick start).
- returns: the converged density matrix `P_fin`, the mean-field (Fock) matrix `F_fin`, the total energy `E_fin`, the chemical potential `mu_fin`, the iteration count `n_iter`, and a small `history` dict with energy and commutator traces.
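The `comm_tol` and `diis_size` parameters and the commutator traces in `history` suggest a Pulay-style DIIS scheme in which convergence is monitored through a commutator such as [F, P], which vanishes at self-consistency. The NumPy sketch below shows a single DIIS extrapolation step under that assumption. It is illustrative only: jax_hf's actual mixing (including its EDIIS component) may differ, and `commutator_error` and `diis_extrapolate` are hypothetical names.

```python
import numpy as np


def commutator_error(F, P):
    """Per-k commutator [F, P] = F P - P F (assumed DIIS error vector)."""
    return F @ P - P @ F  # batched matmul over the trailing (d, d) axes


def diis_extrapolate(fock_history, error_history):
    """Pulay DIIS: return sum_i c_i F_i, with sum_i c_i = 1 minimizing |sum_i c_i e_i|^2."""
    n = len(error_history)
    # Gram matrix B_ij = Re <e_i, e_j>; np.vdot flattens and conjugates its first argument
    B = np.array([[np.vdot(ei, ej).real for ej in error_history] for ei in error_history])
    # Augmented system: the last row/column enforce sum_i c_i = 1 via a Lagrange multiplier
    A = np.zeros((n + 1, n + 1))
    A[:n, :n] = B
    A[:n, n] = -1.0
    A[n, :n] = -1.0
    rhs = np.zeros(n + 1)
    rhs[n] = -1.0
    # lstsq tolerates a (near-)singular B, e.g. when stored error vectors are linearly dependent
    coeffs = np.linalg.lstsq(A, rhs, rcond=None)[0][:n]
    F_new = sum(c * F for c, F in zip(coeffs, fock_history))
    return F_new, coeffs


# Usage: two iterates whose errors cancel exactly are averaged with equal weights
rng = np.random.default_rng(1)
F1, F2, E = (rng.normal(size=(4, 4, 2, 2)) for _ in range(3))
F_new, c = diis_extrapolate([F1, F2], [E, -E])
print(np.allclose(c, [0.5, 0.5]))  # True
```

In an SCF loop, only the most recent `diis_size` (F, error) pairs would typically be kept, for example in a `collections.deque(maxlen=diis_size)`.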
Versions are derived from git tags using setuptools_scm: a tag such as v1.2.3 produces version 1.2.3, while builds from untagged commits produce development versions.
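To check which version is actually installed (for instance, whether an editable install resolved to a development version), the standard-library `importlib.metadata` module can read the distribution metadata. This sketch assumes the distribution name is `jax-hf`, as in the pip command; `installed_version` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="jax-hf"):
    """Return the installed version string of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())
```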
License: BSD 2-Clause (see LICENSE).
Third‑party notices: see THIRD_PARTY_NOTICES.md.
Author: Dr. Tobias Wolf