FhG-IISB/jNO


jNO (jax Neural Operators) is a JAX-native library for training neural operators and physics-informed networks. It unifies data-driven operator regression, residual-based PINN training, mesh-aware FEM/variational PINNs, and foundation-model fine-tuning under one symbolic tracing language — write the PDE once, compile once, train, evaluate, and checkpoint without rewriting the surrounding code.

Note

Research-level repository under active development. The public API is stabilising but may change between minor versions.

What you can do with jNO

| Capability | Status | Notes |
|---|---|---|
| Forward PINNs (residual minimisation) | stable | Hard or soft BC enforcement |
| Operator learning (DeepONet, FNO, U-Net, PROSE via foundax) | stable | PDE-residual or data-driven; three architectures showcased |
| Inverse problems (parameter recovery, surrogate inversion) | stable | |
| FEM / Variational PINNs | stable | TRI3 / TRI6 / QUAD4 elements, weak-form assembly; see known limitations |
| Adaptive resampling (RAD, RARD, CR3, R3, pinnfluence) | stable | |
| Stochastic PDEs and noise nodes (gaussian / uniform / laplace) | stable | Fokker–Planck, stochastic forcing |
| Bayesian PINNs (NUTS, HMC, MALA, SGLD, SGHMC, VI) | stable | 14 worked tutorials |
| Parameter-efficient fine-tuning (LoRA, DoRA, rsLoRA, PiSSA, VeRA, LoKr, OFT, IA3) | stable | Chain `.lora(...)` on any wrapped model |
| Training explainability (gradient conflict, NTK, Hessian, loss landscape, input sensitivity, residual stats) | stable | |
| Foundation-model integration (foundax MLPs, transformers, DeepONet, FNO, PROSE) | stable | Wrap any Equinox module via `jno.nn.wrap(...)` |
| Hybrid data + model parallelism | stable | `jno.core(..., mesh=(batch, model))` |
| W&B logging + Orbax checkpointing | stable | |
| IREE / MLIR compiled inference for deployment | stable | |
| Hyperparameter / architecture search | beta | Grid + Nevergrad |
| Multi-physics coupling | beta | HyCo patterns developing |
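
The adaptive-resampling row covers methods such as RAD (residual-based adaptive distribution), which redraw collocation points with probabilities weighted by the current PDE residual. The sketch below illustrates that idea in plain JAX; it is not jNO's implementation. The function name `rad_resample`, the toy residual, and the defaults `k=1`, `c=1` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def rad_resample(key, candidates, residual_fn, n_points, k=1.0, c=1.0):
    """Draw n_points distinct rows of `candidates` with RAD-style weights.

    The sampling density is |r(x)|^k / mean(|r|^k) + c: regions with a large
    residual are favoured, while c keeps some probability mass everywhere.
    """
    r = jnp.abs(residual_fn(candidates)) ** k  # shape (N,)
    density = r / (jnp.mean(r) + 1e-12) + c  # small epsilon guards against r == 0
    probs = density / jnp.sum(density)
    idx = jax.random.choice(
        key, candidates.shape[0], shape=(n_points,), replace=False, p=probs
    )
    return candidates[idx]


# Hypothetical residual that peaks along x = 0.5; stands in for |PDE residual|.
def toy_residual(pts):
    return jnp.exp(-50.0 * (pts[:, 0] - 0.5) ** 2)


key_cand, key_pick = jax.random.split(jax.random.PRNGKey(0))
candidates = jax.random.uniform(key_cand, (4096, 2))  # dense pool in [0, 1]^2
new_points = rad_resample(key_pick, candidates, toy_residual, n_points=512)
print(new_points.shape)
```

Sampling from a fixed candidate pool keeps the step cheap and jit-friendly; the constant `c` trades off exploitation of high-residual regions against uniform coverage.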

Around 50 worked tutorials cover elliptic, parabolic, hyperbolic, coupled, inverse, integral, stochastic, FEM/variational, Bayesian, and operator-learning problems; see the tutorials index.
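
The fine-tuning row in the capability table lists LoRA among the supported adapters. As a rough picture of what such an adapter does, the plain-JAX sketch below freezes a weight matrix W and trains only a low-rank update, W + (alpha / r) B A, with B initialised to zero so the adapted layer starts out identical to the frozen one. It is not jNO's `.lora(...)` implementation; `init_lora`, `lora_apply`, `alpha`, and the shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def init_lora(key, d_in, d_out, rank):
    """Low-rank factors: A is small and random, B is zero so the update starts at 0."""
    a = 0.01 * jax.random.normal(key, (rank, d_in))
    b = jnp.zeros((d_out, rank))
    return {"A": a, "B": b}


def lora_apply(w_frozen, lora, x, alpha=16.0):
    """Linear layer y = x (W + (alpha / r) B A)^T with W kept frozen."""
    rank = lora["A"].shape[0]
    delta = lora["B"] @ lora["A"]  # (d_out, d_in), rank at most r
    return x @ (w_frozen + (alpha / rank) * delta).T


key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (8, 16))  # frozen "pretrained" weight (d_out=8, d_in=16)
lora = init_lora(key_a, d_in=16, d_out=8, rank=2)
x = jax.random.normal(key_x, (4, 16))


def loss(lora_params):
    # Only the LoRA factors are differentiated; w is a closed-over constant.
    return jnp.mean(lora_apply(w, lora_params, x) ** 2)


grads = jax.grad(loss)(lora)
print(grads["A"].shape, grads["B"].shape)
```

Because B starts at zero, the gradient with respect to A is exactly zero on the first step and only B moves initially, which is the usual LoRA behaviour; the trainable parameter count is r(d_in + d_out) instead of d_in * d_out.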

Install

```shell
pip install "jax-neural-operators[fem]"
```

GPU support is included by default (jNO depends on jax[cuda]). See the Installation guide for Pixi, Docker, and setups pinned to a specific CUDA version.
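
To confirm that the installed JAX build actually sees a GPU, one can query the default backend and the device list. This uses only standard JAX calls (`jax.default_backend()`, `jax.devices()`) and nothing jNO-specific; on a machine without a visible CUDA device, JAX typically falls back to CPU with a warning, and the script reports that.

```python
import jax

backend = jax.default_backend()  # "gpu", "tpu", or "cpu"
devices = jax.devices()  # devices for the default backend
print(f"JAX backend: {backend}")
for d in devices:
    print(f"  {d.platform}: {d}")
if backend != "gpu":
    print("No GPU visible to JAX; training will run on CPU.")
```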

Foundation models and other neural operator architectures live in a separate repository, foundax, so they can also be used on their own; foundax is installed automatically as a jNO dependency.

Example

2-D Poisson with DeepONet: full end-to-end script

```python
import jno
import jax
import optax
import foundax

run_dir = jno.setup("./runs/test")  # renamed from `dir` to avoid shadowing the builtin

# Domain
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")

random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE Loss

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callbacks.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{run_dir}/training.png")
jno.save(crux, f"{run_dir}/model.pkl")

# Inference via test domain on a finer mesh
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
# k drawn from [0.1, 1.9], wider than the training range [0.5, 1.5]
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)
```

Citation

If you use jNO in academic work, please cite:

```bibtex
@article{armbruster2026jno,
  title   = {jNO: A JAX Library for Neural Operator and Foundation Model Training},
  author  = {Armbruster, Leon and Ramesh, Rathan and Kruse, Georg and Straub, Christopher},
  journal = {arXiv preprint arXiv:2605.10159},
  year    = {2026},
  doi     = {10.48550/arXiv.2605.10159},
  url     = {https://arxiv.org/abs/2605.10159}
}
```

AI Disclosure

Parts of this codebase — including model ports, tests, and documentation — were developed with the assistance of AI coding tools. All contributions are reviewed and tested to the best of our ability, but mistakes may remain; please open an issue if you spot one.
