
MEG-XL


MEG-XL is a pre-trained model for non-invasive electrophysiological brain signals (MEG/EEG). It uses long-context pre-training on MEG to learn contextualised transferable representations, enabling data-efficient fine-tuning for neural decoding. When fine-tuned, MEG-XL achieves state-of-the-art brain-to-text word decoding accuracy while requiring significantly less downstream data than prior approaches.

Paper: arXiv:2602.02494 | Model weights: HuggingFace

[Figure: MEG-XL overview]

If you find this work helpful in your research, please cite the paper:

@article{jayalath2026megxl,
  title={{MEG-XL}: Data-Efficient Brain-to-Text via Long-Context Pre-Training},
  author={Jayalath, Dulhan and Parker Jones, Oiwi},
  journal={arXiv preprint arXiv:2602.02494},
  year={2026}
}

Table of Contents

  • Requirements
  • Setup
  • Quick Start
  • Project Structure
  • Fine-tuning MEG-XL for Brain-to-Text
  • Linear Probing MEG-XL for Brain-to-Text
  • Pre-training MEG-XL
  • Supported Datasets
  • Acknowledgements

Requirements

  • Python >= 3.12
  • For Python package dependencies, see requirements.txt

Setup

MEG-XL Setup

  1. Create and activate a virtual environment with Python >= 3.12:
conda create -n megxlenv python=3.12.12
conda activate megxlenv
  2. Install the required pip packages: pip install -r requirements.txt
  3. Download the pre-trained MEG-XL weights from HuggingFace
  4. Follow the notes below depending on how you wish to use MEG-XL

Quick Start

import torch
from brainstorm.neuro_tokenizers.biocodec.model import BioCodecModel
from brainstorm.models.criss_cross_transformer import CrissCrossTransformerModule

# Load tokenizer
tokenizer = BioCodecModel._get_optimized_model()
ckpt = torch.load("brainstorm/neuro_tokenizers/biocodec_ckpt.pt", map_location="cuda")
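# Strip the '_orig_mod.' prefix that torch.compile adds to parameter names when a compiled model is saved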
tokenizer.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in ckpt["model_state_dict"].items()})
tokenizer.eval()

# Load MEG-XL
checkpoint = torch.load("path/to/megxl_checkpoint.ckpt", map_location="cuda")
hparams = checkpoint['hyper_parameters']
model = CrissCrossTransformerModule(
    tokenizer=tokenizer,
    **hparams
).to("cuda")
# Skip loading RoPE rotation weights (they are computed deterministically)
state_dict = checkpoint['state_dict']
filtered_state_dict = {
    k: v for k, v in state_dict.items() if 'rope_embedding_layer.rotate' not in k
}
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
model.eval()

# Prepare inputs (shapes for 150s segment at 50Hz with 306 MEG channels)
# meg: [batch, channels, time] - raw MEG signal
# sensor_xyz: [batch, channels, 3] - sensor positions (normalized)
# sensor_abc: [batch, channels, 3] - sensor orientations
# sensor_types: [batch, channels] - 0=gradiometer, 1=magnetometer
# sensor_mask: [batch, channels] - 1=valid sensor, 0=padding
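
# As a hypothetical smoke test, random tensors with these shapes can be built as follows
# (dtypes and random values are assumptions for illustration only; real inputs come from
#  the preprocessing pipeline in brainstorm/data/):
batch, channels, sfreq, seconds = 1, 306, 50, 150
meg = torch.randn(batch, channels, sfreq * seconds, device="cuda")
sensor_xyz = torch.rand(batch, channels, 3, device="cuda")
sensor_abc = torch.randn(batch, channels, 3, device="cuda")
sensor_types = torch.randint(0, 2, (batch, channels), device="cuda")  # random mix, just for the shape check
sensor_mask = torch.ones(batch, channels, device="cuda")              # all sensors valid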

# Forward pass (apply_mask=False for inference)
with torch.no_grad():
    output = model(meg, sensor_xyz, sensor_abc, sensor_types, sensor_mask, apply_mask=False)
    features = output["features"]  # [batch, channels, time_tokens, hidden_dim]

Project Structure

MEG-XL/
├── configs/                         # Hydra YAML configs for training and evaluation
│
└── brainstorm/
    ├── train_criss_cross_multi.py                                # Multi-dataset pre-training script
    ├── evaluate_criss_cross_word_classification.py               # Word classification eval with fine-tuning
    ├── evaluate_criss_cross_word_classification_linear_probe.py  # Word classification eval with frozen backbone
    │
    ├── data/
    │   ├── utils.py                     # Sensor position normalization utilities
    │   ├── preprocessing.py             # MEG preprocessing (filtering, resampling, caching)
    │   ├── samplers.py                  # Recording-level shuffle sampler for efficient I/O
    │   ├── lightning_datamodule.py      # PyTorch Lightning DataModule for single dataset
    │   ├── multi_datamodule.py          # DataModule for multi-dataset pre-training
    │   ├── multi_dataset.py             # Wrapper combining multiple MEG datasets
    │   ├── subsampled_dataset.py        # Wrapper for recording subsampling with sampler compat
    │   ├── *_dataset.py                 # Per-corpus dataset implementations
    │   └── *_word_aligned_dataset.py    # Per-corpus word-aligned segment datasets
    │
    ├── models/
    │   ├── criss_cross_transformer.py  # Main model with temporal masking and RVQ prediction
    │   ├── spatial_attention.py        # Gaussian Fourier embeddings for 3D sensor positions
    │   └── attentional/                # Spatial-temporal attention modules
    │
    ├── losses/
    │   └── contrastive.py  # CLIP-style contrastive loss
    │
    └── neuro_tokenizers/
        ├── biocodec_ckpt.pt  # Pre-trained BioCodec checkpoint
        └── biocodec/         # Neural signal tokenizer with RVQ

Fine-tuning MEG-XL for Brain-to-Text

python -m brainstorm.evaluate_criss_cross_word_classification \
    --config-name=eval_criss_cross_word_classification_{armeni,gwilliams,libribrain} \
    model.criss_cross_checkpoint=/path/to/your/checkpoint.ckpt

Notes:

  • Requires 1 GPU with >= 80GB VRAM (if more VRAM is available, disable activation checkpointing for faster training)
  • Download the dataset (see Supported Datasets) and update its path in configs/eval_criss_cross_word_classification_{armeni,gwilliams,libribrain}.yaml
  • For unsupported datasets, implement a word-aligned data loader following brainstorm/data/armeni_word_aligned_dataset.py (a minimal sketch is shown below)
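
As a starting point, here is a hypothetical skeleton for such a loader. It assumes a standard PyTorch map-style Dataset that yields one fixed-length MEG segment per word occurrence together with the sensor metadata used in the Quick Start example; the class name, field names, segment layout, and constructor arguments are all assumptions, so check brainstorm/data/armeni_word_aligned_dataset.py for the actual interface.

import torch
from torch.utils.data import Dataset

class MyCorpusWordAlignedDataset(Dataset):
    """Hypothetical word-aligned dataset: one MEG segment per word occurrence."""

    def __init__(self, recordings, word_events, vocab, segment_samples=150 * 50):
        # recordings: dict mapping recording id -> dict with preprocessed MEG ([channels, time])
        #             and per-recording sensor metadata
        # word_events: list of (recording_id, onset_sample, word) tuples
        # vocab: dict mapping word -> integer class label
        self.recordings = recordings
        self.word_events = word_events
        self.vocab = vocab
        self.segment_samples = segment_samples

    def __len__(self):
        return len(self.word_events)

    def __getitem__(self, idx):
        rec_id, onset, word = self.word_events[idx]
        rec = self.recordings[rec_id]
        # Fixed-length context window ending at the word onset (the exact layout is an assumption)
        start = max(0, onset - self.segment_samples)
        meg = torch.as_tensor(rec["data"][:, start:onset], dtype=torch.float32)
        return {
            "meg": meg,                               # [channels, time]
            "sensor_xyz": rec["sensor_xyz"],          # [channels, 3] normalized positions
            "sensor_abc": rec["sensor_abc"],          # [channels, 3] orientations
            "sensor_types": rec["sensor_types"],      # [channels] 0=gradiometer, 1=magnetometer
            "sensor_mask": torch.ones(meg.shape[0]),  # [channels] all sensors valid
            "label": self.vocab[word],                # word class index
        }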

Linear Probing MEG-XL for Brain-to-Text

python -m brainstorm.evaluate_criss_cross_word_classification_linear_probe \
    --config-name=eval_criss_cross_word_classification_linear_probe_{armeni,gwilliams,libribrain} \
    model.criss_cross_checkpoint=/path/to/your/checkpoint.ckpt

Notes:

  • Requires 1 GPU with >= 40GB VRAM
  • See fine-tuning notes above for dataset setup
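
Conceptually, the linear probe keeps the pre-trained backbone frozen and trains only a linear word classifier on top of its features. The sketch below illustrates the idea; hidden_dim, num_words, and train_loader are hypothetical placeholders, and mean-pooling over channels and time tokens is just one plausible choice, so refer to the evaluation script for the actual probe and training loop.

# Minimal linear-probe sketch (assumes `model` from the Quick Start is loaded on "cuda")
for p in model.parameters():       # freeze the pre-trained backbone
    p.requires_grad = False
model.eval()

hidden_dim, num_words = 512, 250   # hypothetical sizes: use the model's hidden size and your vocabulary size
probe = torch.nn.Linear(hidden_dim, num_words).to("cuda")
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)

for batch in train_loader:         # hypothetical DataLoader over a word-aligned dataset
    with torch.no_grad():
        out = model(
            batch["meg"].to("cuda"),
            batch["sensor_xyz"].to("cuda"),
            batch["sensor_abc"].to("cuda"),
            batch["sensor_types"].to("cuda"),
            batch["sensor_mask"].to("cuda"),
            apply_mask=False,
        )
    # Pool [batch, channels, time_tokens, hidden_dim] -> [batch, hidden_dim]
    feats = out["features"].mean(dim=(1, 2))
    loss = torch.nn.functional.cross_entropy(probe(feats), batch["label"].to("cuda"))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()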

Pre-training MEG-XL

python brainstorm/train_criss_cross_multi.py \
    --config-name=train_criss_cross_multi_50hz_med

Notes:

  • Requires 1 GPU with >= 80GB VRAM (if more VRAM is available, disable activation checkpointing for faster training)
  • Download the pre-training datasets and update their paths in configs/train_criss_cross_multi_50hz_med.yaml

Supported Datasets

Split          Dataset      Link
Pre-training   CamCAN       mrc-cbu.cam.ac.uk
Pre-training   MOUS         data.ru.nl
Pre-training   SMN4Lang     OpenNeuro
Fine-tuning    MEG-MASC     OSF
Fine-tuning    Armeni       data.ru.nl
Fine-tuning    LibriBrain   HuggingFace

Acknowledgements

We thank the authors of BioCodec for sharing their neural tokenizer code and checkpoint, the authors of BrainOmni for their criss-cross attention implementation, and Stéphane d'Ascoli for sharing the D-SigLIP contrastive loss code.