# MEG-XL

MEG-XL is a pre-trained model for non-invasive electrophysiological brain signals (MEG/EEG). It uses long-context pre-training on MEG to learn contextualised, transferable representations, enabling data-efficient fine-tuning for neural decoding. When fine-tuned, MEG-XL achieves state-of-the-art brain-to-text word decoding accuracy while requiring significantly less downstream data than prior approaches.
Paper: arXiv:2602.02494 | Model weights: HuggingFace
If you find this work helpful in your research, please cite the paper:
```bibtex
@article{jayalath2026megxl,
title={{MEG-XL}: Data-Efficient Brain-to-Text via Long-Context Pre-Training},
author={Jayalath, Dulhan and Parker Jones, Oiwi},
journal={arXiv preprint arXiv:2602.02494},
year={2026}
}
```

## Contents

- Requirements
- Setup
- Quick Start
- Project Structure
- Fine-tuning
- Linear Probing
- Pre-training
- Supported Datasets
## Requirements

- Python >= 3.12
- For Python packages, see `requirements.txt`
## Setup

- Create and activate a virtual environment with Python >= 3.12:

  ```bash
  conda create -n megxlenv python=3.12.12
  conda activate megxlenv
  ```

- Install the required pip packages:

  ```bash
  pip install -r requirements.txt
  ```

- Download the pre-trained MEG-XL weights from HuggingFace (see the sketch after this list)
- Follow the notes below depending on how you wish to use MEG-XL
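If the released weights live in a standard HuggingFace model repository, the checkpoint can also be fetched programmatically. This is only a sketch: the `repo_id` and `filename` below are placeholders, not the published identifiers.

```python
# Sketch only: fetch the MEG-XL checkpoint from the HuggingFace Hub.
# "org/MEG-XL" and "megxl_checkpoint.ckpt" are placeholder names, not the
# actual published repo id / filename.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="org/MEG-XL",              # placeholder repo id
    filename="megxl_checkpoint.ckpt",  # placeholder filename
)
print(f"Checkpoint downloaded to {ckpt_path}")
```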
## Quick Start

```python
import torch
from brainstorm.neuro_tokenizers.biocodec.model import BioCodecModel
from brainstorm.models.criss_cross_transformer import CrissCrossTransformerModule
# Load tokenizer
tokenizer = BioCodecModel._get_optimized_model()
ckpt = torch.load("brainstorm/neuro_tokenizers/biocodec_ckpt.pt", map_location="cuda")
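# The checkpoint appears to have been saved from a torch.compile-wrapped model,
# so strip the "_orig_mod." prefix it adds to parameter names before loading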
tokenizer.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in ckpt["model_state_dict"].items()})
tokenizer.eval()
# Load MEG-XL
checkpoint = torch.load("path/to/megxl_checkpoint.ckpt", map_location="cuda")
hparams = checkpoint['hyper_parameters']
model = CrissCrossTransformerModule(
tokenizer=tokenizer,
**hparams
).to("cuda")
# Skip loading RoPE weights (computed deterministically)
state_dict = checkpoint['state_dict']
filtered_state_dict = {
    k: v for k, v in state_dict.items()
    if 'rope_embedding_layer.rotate' not in k
}
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
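# missing_keys is expected to list only the skipped RoPE entries; anything else
# suggests a mismatch between the checkpoint and the model hyperparameters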
model.eval()
# Prepare inputs (shapes for 150s segment at 50Hz with 306 MEG channels)
# meg: [batch, channels, time] - raw MEG signal
# sensor_xyz: [batch, channels, 3] - sensor positions (normalized)
# sensor_abc: [batch, channels, 3] - sensor orientations
# sensor_types: [batch, channels] - 0=gradiometer, 1=magnetometer
# sensor_mask: [batch, channels] - 1=valid sensor, 0=padding
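# --- Illustrative placeholders only (not from the original example) ---
# Replace with real preprocessed MEG data; shapes assume batch=1, 306 sensors,
# and 150 s at 50 Hz = 7500 samples, matching the comments above. The dtypes of
# sensor_types / sensor_mask are assumptions and may need to match the model.
batch, n_channels, n_times = 1, 306, 150 * 50
meg = torch.randn(batch, n_channels, n_times, device="cuda")
sensor_xyz = torch.rand(batch, n_channels, 3, device="cuda")
sensor_abc = torch.rand(batch, n_channels, 3, device="cuda")
sensor_types = torch.zeros(batch, n_channels, dtype=torch.long, device="cuda")  # all gradiometers here
sensor_mask = torch.ones(batch, n_channels, device="cuda")  # every sensor valid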
# Forward pass (apply_mask=False for inference)
with torch.no_grad():
output = model(meg, sensor_xyz, sensor_abc, sensor_types, sensor_mask, apply_mask=False)
features = output["features"]  # [batch, channels, time_tokens, hidden_dim]
```

## Project Structure

```
MEG-XL/
├── configs/                                                      # Hydra YAML configs for training and evaluation
│
└── brainstorm/
    ├── train_criss_cross_multi.py                                # Multi-dataset pre-training script
    ├── evaluate_criss_cross_word_classification.py               # Word classification eval with fine-tuning
    ├── evaluate_criss_cross_word_classification_linear_probe.py  # Word classification eval with frozen backbone
    │
    ├── data/
    │   ├── utils.py                        # Sensor position normalization utilities
    │   ├── preprocessing.py                # MEG preprocessing (filtering, resampling, caching)
    │   ├── samplers.py                     # Recording-level shuffle sampler for efficient I/O
    │   ├── lightning_datamodule.py         # PyTorch Lightning DataModule for single dataset
    │   ├── multi_datamodule.py             # DataModule for multi-dataset pre-training
    │   ├── multi_dataset.py                # Wrapper combining multiple MEG datasets
    │   ├── subsampled_dataset.py           # Wrapper for recording subsampling with sampler compat
    │   ├── *_dataset.py                    # Per-corpus dataset implementations
    │   └── *_word_aligned_dataset.py       # Per-corpus word-aligned segment datasets
    │
    ├── models/
    │   ├── criss_cross_transformer.py      # Main model with temporal masking and RVQ prediction
    │   ├── spatial_attention.py            # Gaussian Fourier embeddings for 3D sensor positions
    │   └── attentional/                    # Spatial-temporal attention modules
    │
    ├── losses/
    │   └── contrastive.py                  # CLIP-style contrastive loss
    │
    └── neuro_tokenizers/
        ├── biocodec_ckpt.pt                # Pre-trained BioCodec checkpoint
        └── biocodec/                       # Neural signal tokenizer with RVQ
```
## Fine-tuning

```bash
python -m brainstorm.evaluate_criss_cross_word_classification \
    --config-name=eval_criss_cross_word_classification_{armeni,gwilliams,libribrain} \
    model.criss_cross_checkpoint=/path/to/your/checkpoint.ckpt
```

Notes:
- Requires 1 GPU with >= 80GB VRAM (if more VRAM is available, disable activation checkpointing for faster training)
- Download the dataset (see Supported Datasets) and update the path in `configs/eval_criss_cross_word_classification_{armeni,gwilliams,libribrain}.yaml`
- For unsupported datasets, implement a word-aligned data loader following `brainstorm/data/armeni_word_aligned_dataset.py` (a schematic sketch follows these notes)
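For orientation, the sketch below shows the general shape of a word-aligned PyTorch dataset: one fixed-length MEG segment per word, plus the sensor metadata used in the Quick Start example and an integer word label. The class name and field layout are hypothetical; the exact interface the evaluation script expects should be copied from `brainstorm/data/armeni_word_aligned_dataset.py`.

```python
# Schematic word-aligned dataset for an unsupported corpus (class name and field
# layout are illustrative; mirror brainstorm/data/armeni_word_aligned_dataset.py
# for the interface the evaluation script actually expects).
import torch
from torch.utils.data import Dataset


class MyCorpusWordAlignedDataset(Dataset):  # hypothetical name
    def __init__(self, segments, labels, sensor_xyz, sensor_abc, sensor_types):
        # segments: [num_words, channels, time] MEG windows aligned to word onsets
        # labels:   [num_words] integer word-class ids
        # sensor_*: per-recording sensor geometry, shared across segments
        self.segments = segments
        self.labels = labels
        self.sensor_xyz = sensor_xyz
        self.sensor_abc = sensor_abc
        self.sensor_types = sensor_types

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        meg = self.segments[idx]
        return {
            "meg": meg,
            "sensor_xyz": self.sensor_xyz,
            "sensor_abc": self.sensor_abc,
            "sensor_types": self.sensor_types,
            "sensor_mask": torch.ones(meg.shape[0]),  # all sensors valid
            "label": self.labels[idx],
        }
```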
## Linear Probing

```bash
python -m brainstorm.evaluate_criss_cross_word_classification_linear_probe \
    --config-name=eval_criss_cross_word_classification_linear_probe_{armeni,gwilliams,libribrain} \
    model.criss_cross_checkpoint=/path/to/your/checkpoint.ckpt
```

Notes:
- Requires 1 GPU with >= 40GB VRAM
- See fine-tuning notes above for dataset setup
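For context, linear probing keeps the pre-trained backbone frozen and fits only a linear classifier on its features. The snippet below is a minimal sketch of that idea built on the Quick Start forward pass, not the evaluation script's actual implementation; the mean-pooling and the single `nn.Linear` head are assumptions.

```python
# Minimal linear-probe sketch on frozen MEG-XL features (illustrative only).
# Assumes `model` and the inputs from the Quick Start section are in scope.
import torch
import torch.nn as nn

with torch.no_grad():  # backbone stays frozen
    out = model(meg, sensor_xyz, sensor_abc, sensor_types, sensor_mask, apply_mask=False)
    feats = out["features"]          # [batch, channels, time_tokens, hidden_dim]
    pooled = feats.mean(dim=(1, 2))  # assumed pooling over sensors and time -> [batch, hidden_dim]

num_classes = 250  # e.g. size of the word vocabulary being decoded (placeholder)
probe = nn.Linear(pooled.shape[-1], num_classes).to(pooled.device)
logits = probe(pooled)  # train only `probe` with cross-entropy on word labels
```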
## Pre-training

```bash
python brainstorm/train_criss_cross_multi.py \
    --config-name=train_criss_cross_multi_50hz_med
```

Notes:
- Requires 1 GPU with >= 80GB VRAM (if more VRAM is available, disable activation checkpointing for faster training)
- Download the pre-training datasets (see Supported Datasets) and update the paths in `configs/train_criss_cross_multi_50hz_med.yaml`
## Supported Datasets

| Split | Dataset | Link |
|---|---|---|
| Pre-training | CamCAN | mrc-cbu.cam.ac.uk |
| Pre-training | MOUS | data.ru.nl |
| Pre-training | SMN4Lang | OpenNeuro |
| Fine-tuning | MEG-MASC | OSF |
| Fine-tuning | Armeni | data.ru.nl |
| Fine-tuning | LibriBrain | HuggingFace |
## Acknowledgements

We thank the authors of BioCodec for sharing their neural tokenizer code and checkpoint, the authors of BrainOmni for their criss-cross attention implementation, and Stéphane d'Ascoli for sharing the D-SigLIP contrastive loss code.
