
Masked Diffusion Training #294

Draft: wants to merge 24 commits into base: main.
59 changes: 58 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -34,12 +34,65 @@ class GPTBatch:
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None
mask_indexes: torch.Tensor | None = None
mask_probabilities: torch.Tensor | None = None
masked_token_ids: torch.Tensor | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:

stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
mask_indexes = None
mask_probabilities = None
masked_token_ids = None

token_ids = torch.from_numpy(stacked_ids)

if sampling_parameters.diffusion.enabled:

diffusion_config = sampling_parameters.diffusion

batch_size, seq_len = token_ids.shape
mask_token_id = diffusion_config.mask_token_id

# Generate a random tensor of batch size to seed masking probabilities
t = torch.rand((batch_size,))

# Compute the mask probabilities for every sequence in the batch
p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon

# Do we need to clamp at max_mask_prob?
# p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))

# Input has an additional token for shifting, e.g. [0, 1, 2, 3, 4] -> [1, 2, 3, 4]

# index [0, 1, 2, 3, 4, 5] ->
# The labels are already left shifted x = [A, B, C, D, E, F] ->
# embd = [A, B, C, D, E]
# label = [B, C, D, E, F]
# Last input token is dropped from the processing

# Generate random values for all tokens in the batch and only mask the positions
# where the value is smaller than the mask probability.
mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None]

# Needs further clarification of this padding: ~1% of the data should have partial sequences and padding.
# if diffusion_config.pad_prob > 0:
# pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob
# if pad_mask.any():
# mask_indexes[pad_mask] = True

# Replace masked tokens with the mask token ID to create input for the model.
masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids)

mask_indexes = mask_indexes[:, :-1] # Remove the last token, which is not used for prediction.
mask_probabilities = p_mask

if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]

stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
@@ -49,12 +102,16 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]

return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
token_ids=token_ids,
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
mask_indexes=mask_indexes,
mask_probabilities=mask_probabilities,
masked_token_ids=masked_token_ids,
)


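A minimal standalone sketch of the collate-time masking scheme added above (toy shapes; `epsilon` and `mask_token_id` mirror the `DiffusionMaskingConfig` defaults, everything else is illustrative):

```python
import torch

# Illustrative values; in the actual code they come from GPTSamplingParameters.diffusion.
epsilon, mask_token_id = 1e-3, 103
batch_size, seq_len = 2, 8
token_ids = torch.randint(0, 100, (batch_size, seq_len))

# One noise level per sequence, mapped to a masking probability in [epsilon, 1 - epsilon].
t = torch.rand((batch_size,))
p_mask = (1 - 2 * epsilon) * t + epsilon

# Mask each position independently with its sequence's probability.
mask_indexes = torch.rand((batch_size, seq_len)) < p_mask[:, None]

# The model's input uses the mask token at masked positions; token_ids keeps the original labels.
masked_token_ids = torch.where(mask_indexes, mask_token_id, token_ids)

# The last position only ever appears as a (shifted) label, so its mask flag is dropped.
mask_indexes = mask_indexes[:, :-1]
print(p_mask.shape, masked_token_ids.shape, mask_indexes.shape)
```

This is intended to match how `GPTBatch.mask_indexes`, `mask_probabilities`, and `masked_token_ids` are populated by the collate function.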
39 changes: 39 additions & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -44,6 +44,40 @@ class ShufflingType(str, enum.Enum):
legacy = "legacy"


@config_class(registry=True)
class DiffusionMaskingConfig(Config):
"""Configuration for diffusion-based masking during data preparation."""

enabled: bool = Field(
default=False, desc="Whether to use masked diffusion during training", hint=FieldHint.feature
)

epsilon: float = Field(
default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
)

max_mask_prob: float = Field(
default=0.15, desc="Maximum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
)

pad_prob: float = Field(
default=0.01,
desc="Probability of applying padding to a sample (defaults to 1% of samples).",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)

mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional)

def _validate(self) -> None:
super()._validate()
Assert.lt(self.epsilon, self.max_mask_prob)  # epsilon must be less than max_mask_prob
Assert.lt(
self.max_mask_prob,
1.0,
)


@config_class()
class GPTSamplingConfig(SamplingConfig):
"""
@@ -62,6 +96,10 @@ class GPTSamplingConfig(SamplingConfig):
desc="Shuffling strategy.",
hint=FieldHint.feature,
)
diffusion: DiffusionMaskingConfig = Field(
desc="Configuration for diffusion-based masking during data preparation.",
hint=FieldHint.feature,
)


@dataclasses.dataclass(kw_only=True)
@@ -78,6 +116,7 @@ class GPTSamplingParameters(SamplingParameters):
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
diffusion: DiffusionMaskingConfig


@dataclasses.dataclass(kw_only=True)
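The constraints that `DiffusionMaskingConfig._validate` and its field validators enforce, restated as a small standalone check (a sketch of the intended semantics only, not the actual `Config`/`Assert` machinery):

```python
def check_diffusion_masking_config(epsilon: float, max_mask_prob: float, pad_prob: float) -> None:
    # Mirrors the field validators and _validate above (sketch, not the real implementation).
    assert epsilon > 0, "epsilon must be strictly positive"
    assert max_mask_prob > 0, "max_mask_prob must be strictly positive"
    assert pad_prob >= 0, "pad_prob must be non-negative"
    assert epsilon < max_mask_prob, "epsilon must be less than max_mask_prob"
    assert max_mask_prob < 1.0, "max_mask_prob must be strictly below 1"


check_diffusion_masking_config(epsilon=1e-3, max_mask_prob=0.15, pad_prob=0.01)  # the defaults pass
```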
3 changes: 1 addition & 2 deletions fast_llm/engine/base_model/base_model.py
@@ -137,9 +137,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
# The name (dict key) is used to insert the weight in the kwargs of the forward pass.
return {}

@property
@abc.abstractmethod
def loss_defs(self) -> list[LossDef]:
def get_loss_defs(self) -> list[LossDef]:
pass

def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/evaluator.py
@@ -119,7 +119,7 @@ def setup(
phase=PhaseType.validation,
)

self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()
self._evaluation_iterator = None
self._is_setup = True

2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
@@ -93,7 +93,7 @@ def __init__(
self._stages: list[Stage] = self._multi_stage.stages
self._tied_parameters = self._multi_stage.tied_parameters
self._num_stages = len(self._stages)
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()}

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
2 changes: 1 addition & 1 deletion fast_llm/engine/training/trainer.py
@@ -152,7 +152,7 @@ def __init__(self, config: TrainerConfig):
multi_stage=self._multi_stage,
distributed_config=self._config.model.distributed,
)
self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()

if not self._is_evaluation_only:
steps_per_split = {
30 changes: 22 additions & 8 deletions fast_llm/functional/cross_entropy.py
@@ -15,13 +15,16 @@ def _torch_cross_entropy_forward_backward(
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
loss_weight: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A wrapper for the pytorch implementation of cross-entropy.
The cross-entropy kernels themselves are well-optimized, but the need for explicit casting
and separate forward and backward kernels leads to poor performance.
TODO: loss masking only works with the labels format and if the masking index is set to -100.
"""
assert loss_weight is None, "Loss weight not supported in torch cross-entropy implementation."

# Torch compile doesn't understand this.
with torch.set_grad_enabled(grad_output is not None):
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
@@ -82,6 +85,7 @@ def _fused_cross_entropy_forward_backward(
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
loss_weight: torch.Tensor | None,
group: ProcessGroup | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
@@ -143,15 +147,23 @@ def _fused_cross_entropy_forward_backward(
else:
predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)

per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask
per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
if loss_weight is None:
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

return loss, grad
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
return loss, grad
else:
# Weight each token's loss by its loss weight before averaging.
per_sample_loss = per_sample_loss * loss_weight.flatten()
loss_weight_expanded = loss_weight.reshape(-1, 1)
grad = grad * loss_weight_expanded if grad is not None else None
# Average across all tokens.
return per_sample_loss.mean(), grad


_CROSS_ENTROPY_IMPLEMENTATIONS = {
@@ -170,13 +182,15 @@ def cross_entropy_forward_backward(
implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
logits_scale_factor: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
loss_weight: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Select the appropriate implementation of cross-entropy.
The triton implementation from the triton submodule is the fastest and recommended one.
It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
which is faster and has a relatively small memory overhead.
"""

if target_format == TargetFormat.labels:
Assert.eq(target.shape, logits.shape[:-1])
Assert.eq(target.dtype, torch.int64)
@@ -193,5 +207,5 @@ def cross_entropy_forward_backward(
)
else:
return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
logits, target, loss_mask, grad_output, logits_scale_factor, target_format
logits, target, loss_mask, grad_output, logits_scale_factor, target_format, loss_weight=loss_weight
)
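A plain-PyTorch reference for what the new `loss_weight` path in `_fused_cross_entropy_forward_backward` is meant to compute (unfused, toy shapes, no tensor parallelism; written from the comments above rather than from the kernel itself):

```python
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch * seq, vocab, requires_grad=True)
target = torch.randint(0, vocab, (batch * seq,))
loss_weight = torch.rand(batch, seq)  # e.g. mask_indexes / p_mask coming from the MLM head

# Per-token cross-entropy, weighted *before* averaging over all tokens:
# tokens with weight 0 still count towards the denominator.
per_token = F.cross_entropy(logits, target, reduction="none")
loss = (per_token * loss_weight.flatten()).mean()
loss.backward()  # gradients are scaled by the same per-token weights
```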
3 changes: 3 additions & 0 deletions fast_llm/functional/triton/cross_entropy.py
@@ -125,6 +125,7 @@ def triton_cross_entropy_forward_backward(
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
loss_weight: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes,
@@ -133,6 +134,8 @@ def triton_cross_entropy_forward_backward(
TODO: Better handling of `grad_output = None`
"""
assert TritonConfig.TRITON_ENABLED
assert loss_weight is None, "Loss weight not supported in triton cross-entropy implementation."

# TODO: Improve assumptions.
assert logits.is_contiguous()
assert target.is_contiguous()
3 changes: 3 additions & 0 deletions fast_llm/layers/language_model/config.py
@@ -22,6 +22,7 @@ class LanguageModelDimNames:
class LanguageModelLossNames:
language_model_loss = "language_model_loss"
z_loss = "z_loss"
mlm_loss = "masked_language_model_loss"

@staticmethod
def multi_token_prediction_loss(index: int) -> str:
@@ -38,6 +39,8 @@ class LanguageModelKwargs:
chosen_spans = "chosen_spans"
rejected_spans = "rejected_spans"
loss_mask = "loss_mask"
mask_indexes = "mask_indexes"
mask_probabilities = "mask_probabilities"


@config_class()
84 changes: 84 additions & 0 deletions fast_llm/layers/language_model/head.py
@@ -363,3 +363,87 @@ def _logits_cross_entropy_forward_backward(
# TODO: de-allocate earlier.
del logits
return loss, output_parallel_linear_backward(grad, context) if self.training else None


class MLMHead(LanguageModelHead):
"""
A masked language model head for diffusion-based training.
"""

def __init__(
self,
config: LanguageModelBaseConfig,
tensor_space: TensorSpace,
prediction_distance: int,
):
super().__init__(config, tensor_space, prediction_distance)
self._loss_name = LanguageModelLossNames.mlm_loss

def _logits_cross_entropy_forward_backward(
self,
input_: torch.Tensor,
target: torch.Tensor | None,
loss_mask: torch.Tensor | None,
weight: torch.Tensor,
grad_output: float,
kwargs: dict,
losses: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:

assert target is not None, "MLM head requires target labels"
assert loss_mask is None, "MLM head does not support loss mask"

logits, context = output_parallel_linear_forward(
input_=input_,
weight=weight,
bias=None,
group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
)

masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
# index [0, 1, 2, 3, 4, 5] ->
# The labels are already left shifted x = [A, B, C, D, E, F] ->
# embd = [A, B, C, D, E]
# label = [B, C, D, E, F]

# Question (Pier): if token 2 is masked, the logit at that position still needs to predict token 3; but token 3 is
# already visible to the model as input, so can it just learn to copy the next input token into the masked position?
# Yes. We need to drop those positions from the loss if the next token is not masked.
# We could add corruption to mitigate this further, but it does not seem to matter much judging by other CPT runs (diffuLlama).

last_weight = 0
B = logits.shape[0]

loss_weight = torch.cat(
(
# ar_weight * in_context[:, 1:] +  # not implemented yet
masked_indices[:, 1:] / p_mask[:, None],
# + un_weight * ((1-epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
(last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
# The last position may need some masking-dependent weighting; keep last_weight=0 for now. TODO: decide later.
),
dim=1,
).to(logits.dtype)

loss, grad = cross_entropy_forward_backward(
logits=logits.flatten(0, -2),
target=target,
loss_mask=None,
grad_output=grad_output,
group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
implementation=self._cross_entropy_impl,
logits_scale_factor=self._logits_scale_factor,
loss_weight=loss_weight,
)

# The 1 / p_mask importance weighting is applied through loss_weight.
# MDM reference: https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274

# The per-token loss is averaged over all tokens in the batch (ignored tokens contribute a loss of 0 but still
# count towards the average); this is done inside the cross-entropy function.
# MDM reference: https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L275

del logits
return loss, output_parallel_linear_backward(grad, context) if self.training else None
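A toy illustration of the loss-weight construction in `MLMHead._logits_cross_entropy_forward_backward` (assumed shapes; the `ar_weight`/`un_weight` terms from the SMDM recipe are left out here as well):

```python
import torch

batch, seq = 2, 6
p_mask = torch.full((batch,), 0.25)                        # per-sequence masking probability
masked_indices = torch.rand(batch, seq) < p_mask[:, None]  # which input positions were masked
last_weight = 0.0

# Position i predicts token i + 1, so it gets weight 1 / p_mask only when token i + 1 was masked
# in the input; targets the model has already seen get weight 0. The final position has no such
# flag and gets a fixed last_weight (0 for now).
loss_weight = torch.cat(
    (
        masked_indices[:, 1:] / p_mask[:, None],
        last_weight * torch.ones(batch, 1),
    ),
    dim=1,
)
print(loss_weight)  # shape (batch, seq)
```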
5 changes: 5 additions & 0 deletions fast_llm/layers/transformer/config.py
@@ -485,6 +485,11 @@ class TransformerConfig(LLMBlockConfig):
" Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.",
hint=FieldHint.expert,
)
diffusion: bool = Field(
default=False,
desc="Use masked-diffusion for training.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
with self._set_implicit_default():