
fix loss masking and padding #287

Open · wants to merge 21 commits into main
7 changes: 4 additions & 3 deletions fast_llm/data/data/gpt/config.py
@@ -48,14 +48,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
truncate_documents: bool = Field(
default=True,
truncate_documents: bool | None = Field(
Collaborator:
That works, but we normally handle backward compatibility in _from_dict; see the example in lines 73-90 below. This one can go in GPTTrainerConfig._from_dict.
It also needs a TODO for removal.
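
A minimal sketch of the kind of shim being suggested, written as a standalone helper; the exact keys and the call site in GPTTrainerConfig._from_dict are assumptions, not code from this PR:

```python
import typing


def _move_deprecated_truncate_documents(default: dict[str, typing.Any]) -> dict[str, typing.Any]:
    """Hypothetical helper: relocate the deprecated `data.truncate_documents`
    key to `batch.truncate_documents` before the config is parsed."""
    # TODO: remove once backward compatibility for `data.truncate_documents` is dropped.
    data = default.get("data", {})
    if "truncate_documents" in data:
        default.setdefault("batch", {})["truncate_documents"] = data.pop("truncate_documents")
    return default


# Intended to be called from something like GPTTrainerConfig._from_dict before super()._from_dict(...).
print(_move_deprecated_truncate_documents({"data": {"truncate_documents": False}}))
# {'data': {}, 'batch': {'truncate_documents': False}}
```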

default=None,
desc=(
"Please use batch.truncate_documents instead "
"If enabled, documents may be truncated while being packed to fit the sequence length."
"Otherwise, sequences will be padded such that every document lies entirely within a sample"
" (and documents exceeding the sequence length will be skipped altogether)."
),
hint=FieldHint.feature,
hint=FieldHint.deprecated,
)

def _validate(self) -> None:
1 change: 0 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -127,7 +127,6 @@ def setup(
distributed=distributed,
dataset_name=dataset_name,
tokenizer=self._tokenizer,
truncate_documents=self._config.truncate_documents,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
truncate_documents: bool = True
Collaborator:
Not sure we need to move this, but if we do we need to add backward compatibility.

# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
@@ -90,7 +91,6 @@ class GPTSamplingData(SamplingData):
config: GPTSamplingConfig
parameters: GPTSamplingParameters
tokenizer: "Tokenizer"
truncate_documents: bool = True


@config_class(registry=True)
15 changes: 8 additions & 7 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -89,7 +89,7 @@ def __init__(
self._indexed_dataset = indexed_dataset
self._config = sampling.config
self._parameters = sampling.parameters
self._truncate_documents = sampling.truncate_documents
self._truncate_documents = sampling.parameters.truncate_documents
self._device = torch.device("cuda" if self._config.gpu else "cpu")

if sampling.cache_directory is None:
@@ -144,7 +144,7 @@ def _sample(self) -> None:
" Please make sure Fast-LLM is installed correctly."
)
long_docs_filter = document_sizes > self._parameters.sequence_length + 1
ignored_documents = sum(long_docs_filter)
ignored_documents = long_docs_filter.sum().item()
if ignored_documents:
log_main_rank(
f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.",
@@ -201,9 +201,10 @@ def _sample(self) -> None:

if self._yaml_path is not None and self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
# Hack to make sure unshuffled tokens are loaded
if not self._truncate_documents:
yaml_data["unshuffled_tokens"] = loaded_yaml_data["unshuffled_tokens"]
self._load_yaml_data(yaml_data)
if not self._truncate_documents and not self._parameters.use_preference_loss_spans:
del loaded_yaml_data["unshuffled_tokens"]

if loaded_yaml_data != yaml_data:
raise RuntimeError(
@@ -469,7 +470,7 @@ def __getitem__(self, index: int) -> typing.Any:
token_count += padding_size

# Determine if the document belongs to the requested sample.
if token_count + document_size >= token_start:
if token_count + document_size > token_start:
# Determine which part of the document belong to the sample, and add it to the list.
token_start_index_in_document = max(token_start - token_count, 0)
token_end_index_in_document = min(token_end - token_count, document_size)
@@ -487,7 +488,7 @@
0,
self._parameters.sequence_length + self._parameters.extra_tokens,
)
if span[1] > span[0]:
if span[1] >= span[0]:
loss_masking_spans.append(span)

# Go to the next document.
@@ -547,7 +548,7 @@ def __init__(
):
assert isinstance(sampling, GPTSamplingData)
self._indexed_dataset = indexed_dataset
if not sampling.truncate_documents:
if not sampling.parameters.truncate_documents:
raise NotImplementedError(
"Legacy sampling only supports document truncation. Please use the latest dataset format."
)
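
For readers following the padding behavior these checks support, here is a toy, list-based sketch of the intended semantics when truncation is disabled; the real implementation in sampled.py is index-based, so this is illustrative only:

```python
def pack_without_truncation(doc_lengths: list[int], sequence_length: int) -> list[list[int]]:
    """Toy packer: documents longer than a sample are skipped, and a document
    that would cross a sample boundary is pushed to the next sample (the
    remainder of the previous sample is implicitly padding)."""
    samples: list[list[int]] = []
    current: list[int] = []
    used = 0
    for length in doc_lengths:
        if length > sequence_length + 1:
            continue  # too long to ever fit in one sample: skipped altogether
        if used + length > sequence_length + 1:
            samples.append(current)  # close the sample; the rest is padding
            current, used = [], 0
        current.append(length)
        used += length
    if current:
        samples.append(current)
    return samples


# With sequence_length=8 (9 tokens per sample including the label offset),
# the 6-token document moves to the next sample and the 10-token one is dropped.
print(pack_without_truncation([5, 6, 10, 3], sequence_length=8))  # [[5], [6, 3]]
```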
6 changes: 3 additions & 3 deletions fast_llm/functional/cross_entropy.py
@@ -145,9 +145,9 @@ def _fused_cross_entropy_forward_backward(

per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1)
else:
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
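
A standalone restatement of the masked averaging introduced above, with made-up numbers; clamping the token count to at least one is the usual guard against an all-masked batch:

```python
import torch

per_sample_loss = torch.tensor([0.5, 2.0, 1.5, 3.0])
loss_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 1 = position contributes to the loss

# A plain mean over all positions dilutes the loss with masked entries.
diluted = (per_sample_loss * loss_mask).mean()                               # 0.5

# Averaging over unmasked positions only, with a floor of 1 in the denominator.
masked = (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)  # 1.0
```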

1 change: 1 addition & 0 deletions fast_llm/layers/language_model/config.py
@@ -38,6 +38,7 @@ class LanguageModelKwargs:
chosen_spans = "chosen_spans"
rejected_spans = "rejected_spans"
loss_mask = "loss_mask"
mask_inputs = "mask_inputs"


@config_class()
16 changes: 13 additions & 3 deletions fast_llm/layers/language_model/embedding.py
@@ -84,7 +84,7 @@ def __init__(
)

@torch.compile
def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor:
Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
group = self._tensor_space.distributed.tensor_group
if self._parallel_embeddings:
@@ -101,9 +101,17 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
input_ = split(input_, group=group, dim=0)
if self._use_absolute_position_embeddings:
position_ids = split(position_ids, group=group, dim=0)
embeddings = torch.embedding(self.word_embeddings_weight, input_)
# handle masked tokens
if mask_inputs:
input_mask = input_ >= 0
masked_input = input_ * input_mask
embeddings = torch.embedding(self.word_embeddings_weight, masked_input)
else:
embeddings = torch.embedding(self.word_embeddings_weight, input_)
if self._use_absolute_position_embeddings:
embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight)
if mask_inputs:
embeddings = embeddings * input_mask.unsqueeze(2)
with set_generator(
self._tensor_space.distributed.tp_generator
if self._sequence_parallel
@@ -125,4 +133,6 @@ def forward(
tensor_name="Embedding output",
dtype=self._residual_dtype,
)
return self._forward(input_, kwargs.get(LanguageModelKwargs.position_ids))
return self._forward(
input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs)
)
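
A self-contained sketch of the masking trick used in _forward above: negative token ids (the padding introduced when documents are not truncated) are mapped to a valid index of 0 for the lookup, then the corresponding rows are zeroed so padded positions carry no signal. Values here are made up:

```python
import torch

vocab_size, hidden = 10, 4
weight = torch.randn(vocab_size, hidden)
input_ = torch.tensor([[3, 7, -100, -100]])  # -100 marks padded positions

input_mask = input_ >= 0
# Multiplying by the mask sends every negative id to 0, a valid row,
# so the embedding lookup never indexes out of range.
masked_input = input_ * input_mask
embeddings = torch.embedding(weight, masked_input)
# Zero out the rows that came from padding.
embeddings = embeddings * input_mask.unsqueeze(2)
```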
18 changes: 18 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -1,4 +1,5 @@
import functools
import logging
import typing

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
@@ -17,6 +18,8 @@
from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel
from fast_llm.models.gpt.trainer import GPTTrainer

logger = logging.getLogger(__name__)


class GPTHuggingfaceCheckpointFormat(CheckpointFormat):
support_optimizer: typing.ClassVar[bool] = False
@@ -91,6 +94,15 @@ class GPTBatchConfig(BatchConfig):
desc="Read loss masking spans from the dataset.",
hint=FieldHint.feature,
)
truncate_documents: bool | None = Field(
default=True,
desc=(
"If enabled, documents may be truncated while being packed to fit the sequence length."
"Otherwise, sequences will be padded such that every document lies entirely within a sample"
" (and documents exceeding the sequence length will be skipped altogether)."
),
hint=FieldHint.feature,
)

def _validate(self) -> None:
if self.micro_sequence_length is None:
@@ -183,6 +195,12 @@ def _validate(self) -> None:
self.batch.sequence_length = self.model.base_model.max_position_embeddings
if self.model.base_model.use_megatron_initialization:
set_megatron_distributed_seeds(self.model.distributed)
if self.data.truncate_documents is not None:
logger.warning(
"Using deprecated field `data.truncate_documents`, `batch.truncate_documents` will be overridden if specified. "
"Use `batch.truncate_documents` instead."
)
self.batch.truncate_documents = self.data.truncate_documents
super()._validate()

if self.model.base_model.use_absolute_position_embeddings:
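
A condensed restatement (not the actual code) of how the deprecated field is resolved above: an explicit data.truncate_documents wins, logs a warning, and overrides the batch setting; otherwise the batch value is used as-is:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_truncate_documents(data_value: bool | None, batch_value: bool) -> bool:
    """Illustrative only: the precedence implemented in GPTTrainerConfig._validate."""
    if data_value is not None:
        logger.warning(
            "Using deprecated field `data.truncate_documents`, "
            "`batch.truncate_documents` will be overridden if specified. "
            "Use `batch.truncate_documents` instead."
        )
        return data_value
    return batch_value


resolve_truncate_documents(None, True)   # True: the batch value is used
resolve_truncate_documents(False, True)  # False: the deprecated field wins and warns
```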
7 changes: 4 additions & 3 deletions fast_llm/models/gpt/model.py
@@ -169,6 +169,7 @@ def preprocess_meta(
TransformerKwargs.hidden_dims: hidden_dims,
TransformerKwargs.sequence_length: sequence_length,
TransformerKwargs.sequence_q_dim: sequence_q_dim,
LanguageModelKwargs.mask_inputs: not batch_meta.truncate_documents,
}

sequence_k_pasts = range(
@@ -302,7 +303,7 @@ def preprocess(
if batch.loss_masking_spans is not None:
# avoid changing input tokens
labels = labels.clone()
for i, spans in enumerate(batch.loss_masking_spans):
for idx, spans in enumerate(batch.loss_masking_spans):
if not spans.numel():
continue
valid_spans = spans[
@@ -316,9 +317,9 @@
loss_mask = torch.ones_like(labels, dtype=torch.bool)
for start, end in valid_spans:
if sequence_first:
loss_mask[start : end + 1, i] = False
loss_mask[start : end + 1, idx] = False
else:
loss_mask[i, start : end + 1] = False
loss_mask[idx, start : end + 1] = False
if self._config.distillation_model is not None:
kwargs[LanguageModelKwargs.loss_mask] = loss_mask
labels = torch.where(loss_mask, labels, -100)
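
A compact illustration of what the span loop above produces: positions covered by a loss-masking span get a False mask and a label of -100, the conventional ignore index for cross-entropy. Values are made up and use the batch-first (not sequence-first) layout:

```python
import torch

labels = torch.tensor([[11, 12, 13, 14, 15]])
spans = torch.tensor([[1, 2]])  # mask token positions 1..2, inclusive

loss_mask = torch.ones_like(labels, dtype=torch.bool)
for start, end in spans:
    loss_mask[0, start : end + 1] = False

# Masked positions are replaced with -100 so the loss ignores them.
labels = torch.where(loss_mask, labels, -100)
print(labels)  # tensor([[  11, -100, -100,   14,   15]])
```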
1 change: 1 addition & 0 deletions fast_llm/models/gpt/trainer.py
@@ -30,6 +30,7 @@ def _get_sampling_parameters(
"use_loss_masking_spans": self._config.batch.use_loss_masking_spans,
"use_preference_loss_spans": self._config.model.base_model.enable_dpo,
"cross_document_attention": self._config.batch.cross_document_attention,
"truncate_documents": self._config.batch.truncate_documents,
"extra_tokens": self._config.model.base_model.prediction_heads,
}
)
2 changes: 1 addition & 1 deletion tests/data/common.py
@@ -51,12 +51,12 @@ def get_sampling_data(
num_samples=num_samples,
sequence_length=sequence_length,
vocab_size=vocab_size,
truncate_documents=truncate_documents,
),
cache_directory=cache_directory,
distributed=distributed,
dataset_name=phase.value,
tokenizer=tokenizer,
truncate_documents=truncate_documents,
)

