🎯 Goal (What & Why)
Add support for masked diffusion modelling. This allows training models like LLaDA using a denoising objective over randomly masked spans, enabling both pre- and mid-training. The goal is to de-risk masked diffusion as a modelling strategy and enable conversion from existing auto-regressive (AR) models to diffusion-style training within Fast-LLM.
This feature is a stepping stone toward training and experimenting with diffusion-based language models in a unified framework. It follows the approach outlined in the LLaDA pretraining guidelines.
🚀 Execution Plan
Step 1: What is the smallest working version?
Implement masked diffusion training support:
- Add support for bidirectional/full attention masks (see the attention-mask sketch after this list).
- Add support for uniform BERT-style masking over token sequences.
- Implement forward diffusion logic, e.g.:
```python
import torch

# MASK_TOKEN_ID: id of the dedicated [MASK] token (model/tokenizer specific, defined elsewhere).

def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    # Sample one masking ratio per sequence, in [eps, 1).
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)
    # Mask each token independently with probability p_mask.
    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    noisy_batch = torch.where(masked_indices, MASK_TOKEN_ID, input_ids)
    return noisy_batch, masked_indices, p_mask
```
- Compute the MLM-style loss, e.g.:
```python
import torch.nn.functional as F

logits = model(input_ids=noisy_batch).logits
# Importance-weighted cross-entropy over masked positions, normalized by batch_size * seq_len.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction="none") / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
- Add config flags (see the hypothetical config sketch after this list) to control:
- Use of diffusion loss
- Masking probability and scheduling
- Optional padding for 1% of samples
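For the bidirectional-attention item, a minimal sketch of how a full (non-causal) additive attention bias could be built; the function name and the assumption that the attention layer consumes an additive bias are illustrative, not Fast-LLM's actual API:

```python
import torch

def build_attention_bias(attention_mask, causal: bool):
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    b, l = attention_mask.shape
    # Padding-only restriction: every query may attend to every non-pad key.
    allowed = attention_mask[:, None, :].bool().expand(b, l, l)
    if causal:
        # AR baseline: additionally restrict each query to keys at or before it.
        allowed = allowed & torch.tril(torch.ones(l, l, device=attention_mask.device)).bool()
    # Additive bias: 0 where attention is allowed, -inf where it is not.
    bias = torch.zeros(b, l, l, device=attention_mask.device)
    bias.masked_fill_(~allowed, float("-inf"))
    return bias  # diffusion training would use causal=False (full/bidirectional attention)
```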
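For the config flags, a hypothetical sketch of what the knobs could look like; the class and field names are assumptions, not the actual Fast-LLM config schema:

```python
from dataclasses import dataclass

@dataclass
class DiffusionConfig:
    enabled: bool = False              # use the masked-diffusion (MLM-style) loss instead of the AR loss
    epsilon: float = 1e-3              # lower bound on the masking probability (eps in forward_process)
    masking_schedule: str = "uniform"  # how the masking ratio t is sampled / scheduled
    pad_fraction: float = 0.01         # fraction of samples padded to a random length (the 1% above)
```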
Step 2: What additional optimizations are possible (but optional)?
- Blocked masked diffusion (block-wise, partially autoregressive generation; see the generation sketch after this list).
- Integrate with dataset packing logic and ensure compatibility with packed, padded, and/or truncated sequences.
- Add support for loss masking spans to restrict loss computation to selected regions, e.g. excluding prompt tokens (see the loss-span sketch after this list).
- Add classifier-free guidance (CFG) support using mixed conditional/unconditional sequences (see the CFG sketch after this list).
- Build a LLaDA model converter to load/export HuggingFace checkpoints.
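For blocked masked diffusion, a simplified sketch of block-wise, partially autoregressive generation with greedy one-token-per-step unmasking; LLaDA's actual sampler uses more elaborate remasking schedules, and the function name, step schedule, and `mask_id` plumbing are placeholders:

```python
import torch

@torch.no_grad()
def generate_blockwise(model, prompt_ids, mask_id, gen_len=128, block_len=32):
    # Append gen_len mask tokens after the prompt, then denoise block by block, left to right.
    b = prompt_ids.shape[0]
    pad = torch.full((b, gen_len), mask_id, dtype=prompt_ids.dtype, device=prompt_ids.device)
    x = torch.cat([prompt_ids, pad], dim=1)
    for block_start in range(prompt_ids.shape[1], x.shape[1], block_len):
        block = slice(block_start, min(block_start + block_len, x.shape[1]))
        # Greedy schedule: reveal one token per step until the block is fully unmasked.
        for _ in range(block_len):
            still_masked = x[:, block] == mask_id
            if not still_masked.any():
                break
            logits = model(input_ids=x).logits[:, block]
            conf, pred = logits.softmax(-1).max(-1)
            conf = conf.masked_fill(~still_masked, -1.0)  # only consider still-masked positions
            pick = conf.argmax(dim=-1)                    # highest-confidence masked position per sequence
            rows = torch.arange(b, device=x.device)
            x[rows, block_start + pick] = pred[rows, pick]
    return x
```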
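For loss masking spans, a sketch of how a span mask could restrict both the forward masking and the loss, reusing `forward_process` from Step 1; the `loss_span_mask` tensor and the final normalization are assumptions to be settled during implementation:

```python
import torch.nn.functional as F

# loss_span_mask: (batch, seq_len) bool, True on tokens that may contribute to the loss
# (e.g. response tokens); prompt tokens stay visible to the model and are excluded from the loss.
noisy_batch, masked_indices, p_mask = forward_process(input_ids)
keep_visible = ~loss_span_mask
noisy_batch[keep_visible] = input_ids[keep_visible]   # restore tokens outside the loss spans
masked_indices = masked_indices & loss_span_mask      # loss only on masked in-span tokens

logits = model(input_ids=noisy_batch).logits
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction="none") / p_mask[masked_indices]
loss = token_loss.sum() / loss_span_mask.sum()        # per-allowed-token normalization (an assumption)
```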
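For classifier-free guidance, a sketch of the usual two-pass formulation where the unconditional pass masks out the prompt tokens; the helper, its arguments, and the LLaDA-style extrapolation formula are illustrative assumptions:

```python
import torch

def cfg_logits(model, x, prompt_span_mask, mask_id, cfg_scale=1.0):
    # Unconditional pass: replace the conditioning (prompt) tokens with mask tokens.
    x_uncond = x.masked_fill(prompt_span_mask, mask_id)
    cond = model(input_ids=x).logits
    uncond = model(input_ids=x_uncond).logits
    # Extrapolate away from the unconditional prediction; cfg_scale=0 recovers the conditional logits.
    return uncond + (cfg_scale + 1) * (cond - uncond)
```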
📌 Acceptance Criteria (Must-Haves for Completion)
- MLM-style objective is implemented with optional padding support (1% configurable).
- A model can be trained using masked diffusion loss.
- Feature is documented with config usage and training instructions.
- PR includes performance/impact summary and conforms to Fast-LLM coding standards.
- No refactors beyond what is required for this feature.
🛠️ Project Management
- Assign the project to the Fast-LLM project.
- Set the `Estimate` field (in days) in the GitHub project.
- Use the `Size` field to categorize the PR size (Medium).
- Assign an owner when opening the issue.