WIP: MACE foundation-model integration (small / medium / mpa-0 / matpes parity)#558
Draft
PythonFZ wants to merge 102 commits into
Draft
WIP: MACE foundation-model integration (small / medium / mpa-0 / matpes parity)#558PythonFZ wants to merge 102 commits into
PythonFZ wants to merge 102 commits into
Conversation
Design spec for integrating MACE-MP and MACE-MPA foundation models as a native linen descriptor in apax, with an optional apax[mace] extra (e3nn-jax + cuequivariance-jax) and a convert-mace CLI. Fine-tuning reuses existing shallow-ensemble and property-head infrastructure. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tasked plan covering P0–P6: plumbing, native MACE forward pass in linen, cuequivariance dispatch, torch→apax weight converter + foundation loader, fine-tuning with shallow ensembles, JAX-MD integration, and benchmarks. Companion to spec at docs/superpowers/specs/2026-04-20-mace-foundation-model-integration-design.md. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Converter CLI now accepts canonical model names (e.g. "medium-mpa-0") resolved via mace.calculators.foundations_models.mace_mp(), inheriting upstream's bundled-local / cache / download behavior. - Parity tests compare against the upstream MACECalculator (ASE wrapper) rather than just the raw torch module — matches what downstream users see. - Stress parity + finite-difference force-consistency tests added. - Dev-only `mace-convert` uv dependency group isolates torch/mace-torch from the default apax install. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Implements the smooth polynomial envelope cutoff from Klicpera et al. 2020, used by MACE, as an nn.Module alongside existing BesselBasis/GaussianBasis. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sk stub Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds MaceModelConfig pydantic class with all MACE-specific fields (r_max, num_bessel, num_polynomial_cutoff, max_ell, hidden_irreps, num_interactions, correlation, interaction_cls, use_cueq) plus foundation-model loading fields (pretrained, freeze_backbone, unfreeze_backbone_epoch). Appends to ModelConfig union alongside existing GMNNConfig, EquivMPConfig, So3kratesConfig. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Narrow pretrained field from Optional[Union[str, Path]] to Optional[str], remove now-unused pathlib import, and update the class docstring to say "model" not "descriptor", use PositiveFloat for r_max type, and use numpy-style parameter formatting without inline-union syntax. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add integration smoke test that trains the MACE P0 skeleton for 1 epoch on MD22 stachyose, verifying the full pipeline (config validation → builder → data pipeline → trainer → checkpoint writing). Also adds "default" to the optimizer partition map so unknown parameter names (e.g. skeleton_w) fall back to nn_lr rather than raising a ValueError. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds apax/layers/descriptor/mace_blocks.py with assemble_edge_features, the first P1 building block for MaceRepresentation: computes Bessel radial basis × polynomial cutoff and spherical harmonics of edge vectors. Covers P1.1 of the MACE foundation-model integration plan.
…ojection, add tests) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires P1.1–P1.4 blocks into MaceRepresentation: LinearNodeEmbedding -> N x (InteractionBlock -> ProductBlock) -> concat scalars. Also adds force_irreps_out=True to InteractionBlock.skip_linear so the residual sum stays shape-consistent on the first layer, where node_feats is scalar-only and a pure Linear can't reach higher-l channels. Mask helpers are inlined rather than imported from so3krates, which top-level-imports an unavailable myrto dependency in this env. Closes P1 exit criterion — 19/19 descriptor + block + builder tests green.
P1 (native MACE forward pass) complete on feat/mace-foundation-integration through a0c22c3; 19/19 descriptor + block + builder unit tests green. Optional P1.5 Step 3 (mace-jax random-weight parity test) is deferred to P3's parity phase.
Registers the convert-mace CLI command with typer, stubs run_conversion and load_mace_foundation in transfer_learning/mace_foundation.py, and adds unit tests that verify graceful failure when torch is absent and correct argument signature. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…fixes) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… partial) Implements Steps 1, 2a, 3, 5, 6, 7 of P3.2. Adds mace-convert dep group, run_conversion orchestration, all helpers (_load_torch_foundation_model, _extract_config_from_torch, _validate_no_nan, _extract_norm_consts, _torch_mace_version, _apax_version), gated integration tests (3 skip without torch), and _map_state_to_pytree stub with deferred-mapping docstring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Implement load_mace_foundation() to deserialize a converted .apax/ directory (params.msgpack + config.json) into a flax pytree compatible with MaceRepresentation.init. Adds optional num_elements field to MaceModelConfig and _resolve_short_name() for huggingface_hub fallback. No torch imports in the loader path. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
P3.4: Add apax/nn/mace_foundation_model.py skeleton (MaceFoundationEnergyModel, __call__ raises NotImplementedError pending P3.2 Step 4); add Notes section to MaceRepresentation docstring flagging future return_per_layer_node_feats flag; create gated integration tests in test_mace_parity.py (all skip without torch). P3.5: Add _is_mace_foundation_dir helper and early NotImplementedError guard in ASECalculator.__init__; add tests/unit_tests/md/__init__.py and gated unit test that verifies the error is raised with the correct hint (1 test, passes). P3.6: Enhance _extract_config_from_torch head validation to raise ValueError with clear guidance when an unknown head name is requested on a multi-head model; append test_convert_rejects_unknown_head to test_convert.py (skips without torch). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Flip P3.1/P3.3 and the scaffolded steps of P3.2/P3.4/P3.5/P3.6 to [x]. Annotate the four deferred boxes (P3.2 Step 4 map body; P3.4 Steps 2/3/6 per-layer feats, foundation-model __call__, parity iteration) with a pointer to the `uv sync --group mace-convert --extra mace` workflow that unblocks them. Update the Progress preamble with the four P3 scaffolding commits and the current gap.
…ftMACE Foundation models expose most hyperparameters only through submodules rather than top-level attributes. Walk the submodule tree to read max_ell, hidden_irreps, correlation, num_elements, num_bessel, num_polynomial_cutoff, scale/shift, atomic_numbers, and the interaction variant. Add ``test_extract_config_from_torch_small`` asserting these match the known MACE-MP-0 small config. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… parity) Add ``LinearReadoutBlock``, ``NonLinearReadoutBlock``, and ``ScaleShift`` mirroring torch-mace's primitives. ``NonLinearReadoutBlock`` multiplies SiLU output by torch-e3nn's ``normalize2mom`` constant for bit-for-bit parity. Unit tests assert shape/finite and affine behaviour. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implement the full forward pass mirroring torch-mace's ScaleShiftMACE for single-head, scalar-only-hidden foundation models (MACE-MP-0 small): - InteractionBlock gains a ``foundation_mode`` that uses a per-element ``PerElementSkipTP`` (the torch ``skip_tp``), divides the conv message by ``avg_num_neighbors``, and disables the radial-MLP output activation. - ProductBlock gains ``input_irreps`` (separate from ``hidden_irreps``) and an optional ``post_linear`` to match torch's Equivariant product basis layout. - New ``_scalar_irreps_only`` / ``_default_target_irreps`` helpers. - ``MaceFoundationEnergyModel.__call__`` composes these plus the readout/ scale-shift blocks to emit a per-atom energy. Maps input Z through the model's own atomic-numbers table (required by the upstream indexing convention) and flips dr_vec sign to match torch's edge convention. Smoke test asserts finite output shape ``(n_atoms,)``. Parity mapper and end-to-end test follow in subsequent commits. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
apax ZBL applies one output_scale; torch-mace folds ZBL inside scale_shift. For per-element scale_shift.scale tensors, naively casting to float() silently keeps only the first element (or raises an unhelpful torch RuntimeError). Raise NotImplementedError instead so the failure is loud.
- Drop internal task IDs (I7) from comment + test docstring. - Move pytestmark to module scope (matches T8 sibling). - Error message now points users to "open an issue" instead of just saying "not supported." Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
T4 added ``self.sow("debug", ...)`` calls in apax/nn/models.py and
the MACE blocks/readout/basis-functions/empirical layers so the
layer-by-layer parity harness could harvest intermediates via
``apply(..., mutable=["debug"])``. Production paths in
train/checkpoints.py and transfer_learning/mace_foundation.py were
gated with ``mutable=DenyList("debug")`` to keep the collection out
of the params pytree at init.
But every direct ``model.init(...)`` caller (test helpers in
tests/conftest.py:initialize_model, the in-test ``model.init`` in
tests/integration_tests/md/test_md.py + test_ase_hessian.py, and
any user driving Flax's public model API) bypassed those gates.
Flax's ``init`` defaults to making *all* collections mutable, so
the ``debug`` branch landed under ``params``; downstream
``jax.grad`` / ``jax.vmap`` then leaked ``LinearizeTracer`` /
``BatchTracer`` values that orbax couldn't serialise.
Rather than thread ``DenyList("debug")`` through every external
caller, gate the sow calls themselves with a process-wide flag in
apax/utils/parity_debug.py. The harness opens the gate via the
``parity_debug()`` context manager before its
``apply(..., mutable=["debug"])`` call; everywhere else the sows
are no-ops and ``model.init`` returns a clean params pytree.
Failing tests fixed: test_md::test_run_md, test_md::test_ase_calc,
test_model_loading::test_model_loading,
test_model_loading::test_moved_model_loading,
test_ase_hessian::test_ase_hessian, and all 6
test_api::test_kernel_selection variants.
S22 parity unaffected (MAE 0.000425 meV, well below 1e-5 eV gate).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Restored from pre-merge local additions on this branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the prior MACE-foundation-remediation spec/plan (deleted) with a single design covering the schema re-partition into nested basis / radial_embedding / descriptor / readout groups, the discriminated ``interactions`` typed list (drops ``interaction_cls`` and ``num_interactions``), and the simultaneous Linen-level refactor of ``MaceRepresentation`` to consume a pre-built ``radial_embedding`` submodule, structurally matching how the other apax descriptors take their basis. Branch is pre-merge so no compat shim. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…e-partition design Adds a ``variant: kocer | standard`` field to ``BesselBasisConfig`` so ``ModelBuilder.build_basis_function`` becomes the single dispatch point for all models including MACE, defaulting to ``kocer`` (preserves existing GMNN/EquivMP/So3krates configs) and overridden to ``standard`` by ``MaceModelConfig``. ``MaceRadialEmbedding`` now consumes a pre-built ``basis_fn``. Documents the load-bearing Linen placement of ``distance_transform`` (config-level grouped under radial_embedding, module-level a field on MaceRepresentation) so the converter slot path ``representation/distance_transform/...`` stays stable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ction dispatch Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…interactions list Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…mbedding/descriptor/readout Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_blocks comment to point at the new MaceDescriptorConfig.interactions discriminator Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1e1667c to
d06f73f
Compare
…d design Three follow-ups to PR #558: (a) loud error on transfer-learning shape mismatch with copy-pasteable reset_layers suggestion + full-leaf-path support, (b)+(c) explicit per-property-head readout kind with cross-model error guards, plus a template comment update describing the run-once-paste-reset_layers workflow. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… leaf paths in reset_layers Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…MACE property heads Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…riminator Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Task 4's integration tests revealed that e3nn_jax's Linear layer encodes the output irreps in the parameter dict key string, so n_shallow_ensemble width changes produce sibling orphan leaves rather than same-path shape changes. Task 1's same-path shape check therefore never fires for the canonical foundation→ensemble workflow. Task 1.5 adds a structural-mismatch pass that groups leaves by parent path, detects parents with orphans on BOTH sides, and emits target paths in the suggested reset_layers snippet. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…nn-named keys Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…smatch workflow Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…_finetune_minimal Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously the parent ModelBuilder guard (GMNN/EquivMP/So3krates + kind='mace') explained why it errored but did not tell the user how to recover. Match the tone of the sibling MaceBuilder guard (which tells MACE+kind='standard' users to set kind='mace' or switch model) by suggesting kind='standard' or a model change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sweep of branch review feedback. Removes anything that was speculative,
shimmed, or duplicated; tightens the parts that remain.
Workarounds replaced with real fixes:
- delete apax/utils/parity_debug.py and scripts/mace_layer_parity.py;
remove all sow("debug", ...) gates and DenyList("debug") workarounds
- _map_pair_repulsion / _map_distance_transform dispatch off
config.<...>.trainable instead of try-buffers-then-params; wire
_map_distance_transform into _map_state_to_pytree (was unreachable)
- _map_readouts iterates linear slot keys via _scatter_o3_linear_blocks
instead of literal "w[0,0] 128x0e,1x0e" — works for non-MP-0 widths
- check_for_ensemble filters by "params" collection instead of treating
0-d leaves as size-1
- AgnesiTransform / MaceZBLPairRepulsion rewritten as @nn.compact so all
variables resolve through .value (drops hasattr-based _scalar helper)
Dead code / placeholders dropped (no forward-compat shims):
- model.descriptor.use_cueq (raised NotImplementedError; reserved for P2)
- model.readout.kind and PropertyHead.kind (architecture follows the
parent model type; "standard" energy-head fallback was a footgun)
- InteractionBlock = InteractionBlockResidual back-compat alias
- assemble_edge_features shim and its tests
- M parameter on _scatter_o3_linear_blocks (noqa-marked back-compat)
- InteractionBlockDensity.hidden_irreps "for parity" field
- layer_idx fields on all three interaction blocks + ProductBlock
- **kwargs swallowers in simulate.py energy wrappers
- build_hessian_neighbor_fns one-line alias
- Optional[List]/Optional[dict] widening + normalize_null validator on
TransferLearningConfig, OptimizerConfig.kwargs, LossConfig.parameters,
DataConfig.{shift,scale}_options
- getattr(self.model, "disable_cell_list", False) — disable_cell_list
and nl_skin lifted to ApaxBase
Refactors:
- _build_mace_readout helper unifies energy + property head construction
- ModelBuilder.build_readout no longer knows about MACE
- _path_to_str helper replaces 5+ duplicated path-stringifying calls in
parameter_transfer.py
- BOHR_RADIUS_ANG / COULOMB_EV_ANG / DR_FLOOR hoisted to module level
- except Exception narrowed to (ImportError, AttributeError) /
(AttributeError, RuntimeError) in version probes and norm-const probe
- loss functions: drop mutable {} defaults, parameters/inputs are now
required positional args; mass_weighted_hessian_loss gets a real
numpy-style docstring
Comment / docstring trims throughout; removed history-narrating prose
that documented prior refactors. All 142 unit tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Layer-by-layer parity harness consumed via the parity_debug sow mechanism that was removed; the script is no longer functional. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The MACE foundation converter was pulling ``mace_jax.adapters.cuequivariance.symmetric_contraction._convert_native_weights`` through the ``mace-convert`` extra. mace-jax is not on PyPI, was pinned by git URL, and the function we needed was private. The dense linear-algebra work is straightforward to reproduce directly on the cuequivariance descriptor that apax already builds, so vendor it locally. - new apax/transfer_learning/torch_sc_adapter.py with a public ``convert_native_weights`` entry point. Reduced + full-CG paths both supported. Full-CG transform's design-matrix solve replaces mace-jax's nnx ``SymmetricContraction`` with a direct ``cuex.equivariant_polynomial`` application of the same descriptor. - mace_foundation._map_products now imports from the local adapter. - pyproject: remove mace-jax + its [tool.uv.sources] git pin; promote mace-convert from [dependency-groups] to [project.optional-dependencies] so ``pip install apax[mace-convert]`` works. - rename in-package private symbol _INTERACTION_BLOCK_CLS -> INTERACTION_BLOCK_CLS (consumed by MaceRepresentation and tests). Verified: 142 unit tests pass; all 3 mace_parity converter tests pass against MACE-MP-0 small. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… like MPA-0 The converter wired ``_map_distance_transform`` to look at ``representation/distance_transform/...``, but the AgnesiTransform is a field of MaceRadialEmbedding and is called from inside its forward pass — not from MaceRepresentation. Linen materialises the slot under the parent that actually invokes it, so the real path is ``representation/radial_embedding/distance_transform/...``. The dead ``distance_transform: Any`` field on MaceRepresentation never created a slot (linen only registers submodules that are called); it was stale state from an earlier design. Drop the field, fix the converter path, update the tests that used to pass it. Regression test added: ``test_convert_medium_mpa0_with_distance_transform`` runs the full converter against MACE-MPA-0 medium and asserts that the AgnesiTransform's ``a`` / ``q`` / ``p`` / ``covalent_radii`` are present in the converted pytree with values matching the torch source. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Warning
Claude Code against
mace-torchandmace-jax. I have not written this code!Summary
Integrates torch-mace foundation models into apax via a new
apax convert-macepath. End-to-end parity (withinrtol 1e-4energy /rtol 1e-3force) on:MACE-MP-0 small(Residual, scalar-only)MACE-MP-0 medium(Residual, multi-irrep)MACE-MPA-0 medium(Density variants + ZBL + AgnesiTransform)MACE-matpes-r2scan-omat-ft(same architecture as mpa-0)55 commits ahead of
main. Marked draft / WIP for review while the broader feature branch settles; the most recent commit (34188618) closes the mpa-0 / matpes parity gap.What's in the latest commit
Two design specs landed in tandem:
docs/superpowers/specs/2026-05-04-mace-density-zbl-design.mddocs/superpowers/specs/2026-05-04-mace-distance-transform.mdArchitecture
_interaction_scaffoldhelper + three interaction blocks (Residual / Density / DensityResidual) with per-layer dispatch inMaceRepresentation.MaceZBLPairRepulsionempirical correction (faithful port ofmace.modules.radial.ZBLBasis); newoutput_scalefield reproduces torch's "ZBL insidescale_shift" semantics without mutating buffers.AgnesiTransformLinen module +MaceRadialEmbeddingrefactor that threadsZ/idxthrough and applies the transform between cutoff and bessel.MaceModelConfig.interaction_clswidened to a Literal union (single str OR per-layer list); newdistance_transformdiscriminated union;MaceZBLPairRepulsioncorrection config.Converter
density_fnweight scatter; size-aware_scatter_o3_linear_blocksfor non-uniform layer-1linearshapes; per-pathskip_tpscatter._map_pair_repulsionand_map_distance_transformhelpers respect bothparams(trainable) andbuffers(fixed) collections._map_scale_shiftcollapses 2-D atomic_energies tables (matpes ships(1, 89)even single-head).Bug fixes worth calling out
check_for_ensemblecrashed on 0-d scalar buffers — treatsndim==0as size 1 (preserves ensemble semantics;jnp.stacklifts every leaf to>=1-D).AgnesiTransformclipsrandr_0to 0.02 so masked / padding edges withdr=0don't triggerinfgradients viax^(-3.66).[14, 14, 14]to avoid an unrelated apax PBC bug (NaN forces from self-image edges) that exists for every foundation; documented in fixture docstring.Test plan
uv run pytest tests/unit_tests/ --ignore=tests/unit_tests/md— 93 pass, 0 fail.uv run pytest tests/integration_tests/mace -m mace_parity— 17 pass (small / medium / mpa-0 water; matpes periodic SiO₂; ZBL dimer parity atrtol 1e-12).output_scalesemantics make sense for fresh-trained apax models (default1.0= no scaling)._map_scale_shift2-D atomic_energies handling vs. multi-head matpes when an explicit head is passed.Out of scope
RealAgnostic(non-residual non-density) interaction variant — not used by foundations in scope; additive when a model needs it.SoftTransform(the tanh-based distance transform) — additive one-class extension.🤖 Generated with Claude Code