Prototype masked diffusion modeling support #208

@tscholak

Description

🎯 Goal (What & Why)

Add support for masked diffusion modeling. This allows training models like LLaDA with a denoising objective over randomly masked tokens, enabling both pre- and mid-training. The goal is to de-risk masked diffusion as a modeling strategy and enable conversion of existing auto-regressive (AR) models to diffusion-style training within Fast-LLM.

This feature is a stepping stone toward training and experimenting with diffusion-based language models in a unified framework. It follows the approach outlined in the LLaDA pretraining guidelines.

🚀 Execution Plan

Step 1: What is the smallest working version?

Implement masked diffusion training support:

  • Add support for bidirectional/full attention masks (see the attention sketch after this list).
  • Add support for uniform BERT-style masking over token sequences.
  • Implement forward diffusion logic, e.g.:
    import torch

    def forward_process(input_ids, eps=1e-3):
        # MASK_TOKEN_ID is the tokenizer-specific mask token id.
        b, l = input_ids.shape
        # Sample one masking ratio t per sequence; eps keeps it strictly positive.
        t = torch.rand(b, device=input_ids.device)
        p_mask = (1 - eps) * t + eps
        # Broadcast the per-sequence ratio to every token position.
        p_mask = p_mask[:, None].repeat(1, l)
        # Mask each token independently with probability p_mask.
        masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
        noisy_batch = torch.where(masked_indices, MASK_TOKEN_ID, input_ids)
        return noisy_batch, masked_indices, p_mask
  • Compute the MLM-style loss over masked positions only, e.g.:
    import torch.nn.functional as F

    logits = model(input_ids=noisy_batch).logits
    # Per-token cross-entropy on masked positions, weighted by 1/p_mask so the
    # loss estimate stays unbiased across masking ratios.
    token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices],
                                 reduction='none') / p_mask[masked_indices]
    # Normalize by the full batch area (batch size * sequence length).
    loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
  • Add config flags (a hypothetical config sketch follows this list) to control:
    • Use of diffusion loss
    • Masking probability and scheduling
    • Optional padding for 1% of samples
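
Bidirectional attention (first bullet above) mostly amounts to dropping the causal mask. A minimal sketch, assuming a PyTorch attention layer where causality is a flag; the names here are illustrative, not Fast-LLM's actual internals:

    import torch
    import torch.nn.functional as F

    def attention(q, k, v, causal=True):
        # Masked diffusion needs every position to attend to every other
        # position, so the usual causal mask is simply disabled.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
    out = attention(q, k, v, causal=False)  # full/bidirectional attention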
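
The config flags could take roughly the following shape; all field names below are hypothetical and would need to follow Fast-LLM's existing config conventions:

    from dataclasses import dataclass

    @dataclass
    class MaskedDiffusionConfig:
        enabled: bool = False              # use the diffusion loss instead of the AR loss
        mask_token_id: int = 0             # tokenizer-specific mask token
        min_mask_prob: float = 1e-3        # eps in forward_process above
        pad_sample_fraction: float = 0.01  # fraction of samples padded to a random length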

Step 2: What additional optimizations are possible (but optional)?

  • Blocked masked diffusion (partially autoregressive generation; see the decoding sketch after this list).
  • Integrate with dataset packing logic and ensure compatibility with packed, padded, and/or truncated sequences.
  • Add support for loss masking spans to restrict loss computation to selected regions, e.g., excluding prompt tokens (see the span-masking sketch after this list).
  • Add classifier-free guidance (CFG) support using mixed conditional/unconditional sequences (see the guidance sketch after this list).
  • Build a LLaDA model converter to load/export HuggingFace checkpoints.
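
Blocked masked diffusion (first bullet above) would generate left to right over fixed-size blocks, iteratively unmasking the most confident positions inside the current block. A rough sketch under those assumptions; the helper below is hypothetical and glosses over batching edge cases:

    import torch

    @torch.no_grad()
    def generate_blocked(model, prompt_ids, mask_id, gen_len=128, block_len=32, steps=8):
        b, p = prompt_ids.shape
        completion = torch.full((b, gen_len), mask_id, dtype=prompt_ids.dtype,
                                device=prompt_ids.device)
        x = torch.cat([prompt_ids, completion], dim=1)
        for start in range(p, p + gen_len, block_len):
            end = start + block_len
            # Refine only the current block; earlier blocks are already fixed.
            for _ in range(steps):
                logits = model(input_ids=x).logits[:, start:end]
                conf, preds = logits.softmax(-1).max(-1)
                # Only still-masked positions compete for unmasking.
                conf = conf.masked_fill(x[:, start:end] != mask_id, -1.0)
                idx = conf.topk(max(1, block_len // steps), dim=-1).indices
                x[:, start:end] = x[:, start:end].scatter(1, idx, preds.gather(1, idx))
        return x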
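
Loss masking spans would compose directly with the training loss from Step 1. A sketch, assuming a boolean loss_mask of shape (batch, length) that is True on tokens allowed to contribute to the loss (e.g., response tokens) and False elsewhere; normalizing by the in-span token count is one reasonable choice:

    # Only positions that are both masked by the forward process and inside a
    # loss span contribute to the objective.
    active = masked_indices & loss_mask
    token_loss = F.cross_entropy(logits[active], input_ids[active],
                                 reduction='none') / p_mask[active]
    loss = token_loss.sum() / loss_mask.sum().clamp(min=1)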
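
For CFG, the usual recipe in masked diffusion LMs runs the model twice per denoising step, once with the prompt visible and once with the prompt itself masked out, then extrapolates between the two sets of logits. A sketch; the guidance scale w and the function below are illustrative:

    import torch

    def cfg_logits(model, noisy_ids, prompt_positions, mask_id, w=1.0):
        # Conditional pass: prompt tokens visible.
        cond = model(input_ids=noisy_ids).logits
        # Unconditional pass: prompt replaced by mask tokens.
        uncond = model(input_ids=noisy_ids.masked_fill(prompt_positions, mask_id)).logits
        # Extrapolate away from the unconditional prediction; w = 0 recovers
        # the plain conditional logits.
        return uncond + (1.0 + w) * (cond - uncond)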

📌 Acceptance Criteria (Must-Haves for Completion)

  • MLM-style objective is implemented, with optional padding of a configurable fraction of samples (1% by default).
  • A model can be trained using masked diffusion loss.
  • Feature is documented with config usage and training instructions.
  • PR includes performance/impact summary and conforms to Fast-LLM coding standards.
  • No refactors beyond what is required for this feature.

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Medium).
  • Assign an owner when opening the issue.

Metadata

Labels

enhancement (New feature or request)
