fix loss masking and padding #287

Open · wants to merge 21 commits into main

Conversation

@sohamparikh (Member) commented Jun 4, 2025

✨ Description

Fixes and improvements for loss masking and padding

Closes #

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

List the key changes introduced in this PR:

  • Discard masked tokens when computing the loss mean
  • Mask negative token ids in the embedding lookup
  • Fix cached yaml comparison when truncate_documents=False
  • Handle the edge case when no padding is required in a sequence
  • Speed up the sum(long_docs_filter) computation, which was extremely slow for large datasets (see the sketch after this list)
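
A minimal sketch of the slowdown in the last bullet, assuming long_docs_filter is a NumPy boolean array marking documents that exceed the sequence length (the surrounding code is illustrative, not the actual Fast-LLM implementation):

```python
import numpy as np

# Hypothetical filter over a large dataset: True where a document is "too long".
long_docs_filter = np.random.rand(10_000_000) > 0.999

# Slow: Python's built-in sum() iterates element by element in the interpreter.
# n_long = sum(long_docs_filter)

# Fast: a single vectorized reduction in C.
n_long = int(long_docs_filter.sum())
print(n_long)
```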

@tobyzl2 (Contributor) left a comment

LGTM!

@sohamparikh requested a review from tobyzl2 on June 4, 2025 at 18:06
@@ -525,8 +529,8 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
elif "unshuffled_tokens" not in data:
# Backward compatibility
# TODO v0.x: Remove
assert self._truncate_documents
data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
assert not self._truncate_documents
Collaborator

That looks wrong; the old format only supported _truncate_documents=True.

assert self._truncate_documents
data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
assert not self._truncate_documents
data["unshuffled_tokens"] = data["dataset"]["tokens_per_epoch"] * data["unshuffled_epochs"]
Collaborator

I believe the backward compatibility is from before we moved things to dataset, so it was right before.

Member Author

Yes, I got a bit confused about the purpose of this. There's still an issue with yaml_data not containing unshuffled_tokens (we can't get it without building the padded cumsum). I pushed a hack to copy it from loaded_yaml_data instead to avoid breaking the flow; not sure if there's a cleaner way.

@@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward(

per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask
per_sample_loss = per_sample_loss[loss_mask]
Collaborator

Why this change? loss_mask is an integer so multiplication should work.

@sohamparikh (Member Author) commented Jun 4, 2025

The loss is not as interpretable or comparable when we include the loss from masked tokens (0) in the average. We start seeing a lot of variance in the reported loss when mixing samples with and without masked tokens.

Collaborator

Oh, so it's to have the right denominator in the mean? Indexing is a bad idea because it introduces a CUDA synchronization point (really slow), but you can divide by loss_mask.sum() instead.

We probably also want to handle the case loss_mask.sum() == 0.

@@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t
input_ = split(input_, group=group, dim=0)
if self._use_absolute_position_embeddings:
position_ids = split(position_ids, group=group, dim=0)
embeddings = torch.embedding(self.word_embeddings_weight, input_)
# mask padded tokens
input_mask = input_ >= 0
Collaborator

I'd prefer not to do this unless padding is enabled because of the extra compute involved. Why do we have negative input anyway?

Member Author

We set the padded tokens to -100 mainly to mask the loss on them. Many tokenizers don't have a pad token, so it's not straightforward to take one from the config either.

Member Author

Let me know if the change is OK now, or whether we need to check for padding via a flag (maybe in kwargs).

Collaborator

We can add a flag, something like what we already do for labels: https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/models/gpt/model.py#L332.
Also, I just noticed the method is within @torch.compile, so the overhead shouldn't be too noticeable.

@sohamparikh (Member Author) commented Jun 5, 2025

something like this? (truncate_documents might also fit better in the batch config)
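
Purely as an illustration (not the author's actual pushed change), a sketch of what flag-gated masking of negative token ids could look like in the embedding forward; the mask_inputs flag and function signature are assumptions:

```python
import torch


def embed_tokens(word_embeddings_weight: torch.Tensor, input_: torch.Tensor, mask_inputs: bool) -> torch.Tensor:
    if mask_inputs:
        # Padded / loss-masked positions carry a negative id (e.g. -100), which is
        # not a valid embedding index, so clamp them to 0 for the lookup ...
        input_mask = input_ >= 0
        masked_input = input_ * input_mask
        embeddings = torch.embedding(word_embeddings_weight, masked_input)
        # ... and zero out the resulting embeddings at those positions.
        embeddings = embeddings * input_mask.unsqueeze(-1)
    else:
        # No padding possible: skip the extra compute entirely.
        embeddings = torch.embedding(word_embeddings_weight, input_)
    return embeddings
```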

@@ -467,6 +468,12 @@ def __getitem__(self, index: int) -> typing.Any:
else:
# Move on to the next sample.
token_count += padding_size
elif document_size + tokens_in_sample == self._parameters.sequence_length + 1:
if token_count + document_size == token_start:
# Document belongs to the current sample but the condition below will include it for the next sample
Collaborator

I'm not following: why are we ignoring the document if it belongs to the current sample? (Also, it clearly belongs to the previous sample.)

Collaborator

From what I understand, in this scenario we'll have token_start_index_in_document == token_end_index_in_document == document_size, so we'll load 0 tokens from the sample. That seems unnecessary but not wrong, and it also doesn't seem related to document_size + tokens_in_sample == self._parameters.sequence_length + 1.

It seems to me the actual fix would be to replace >= with > in the condition below.

Member Author

Oh yes, I got confused because I faced this issue in the multimodal branch, but it only occurs when there are images right after the text tokens. I'll handle it there.

@@ -146,8 +146,13 @@ def _fused_cross_entropy_forward_backward(
per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
unmasked_inputs = loss_mask.sum()
Collaborator

This still causes a CUDA sync. You can just do loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1).
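
A sketch of the sync-free masked mean being suggested here, using clamp_min in place of torch.maximum so the scalar floor of 1 is accepted directly; the names are illustrative:

```python
import torch


def masked_mean_loss(per_sample_loss: torch.Tensor, loss_mask: torch.Tensor | None) -> torch.Tensor:
    if loss_mask is None:
        return per_sample_loss.mean()
    loss_mask = loss_mask.to(per_sample_loss.dtype)
    # Multiply-and-divide keeps everything on the GPU; boolean indexing
    # (per_sample_loss[loss_mask]) would force a host-device sync because the
    # output shape depends on the mask's contents.
    # clamp_min(1) guards against a fully masked batch (loss_mask.sum() == 0).
    return (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp_min(1)
```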

Member Author

For my own understanding, how can I check whether a PyTorch op causes a CUDA sync?
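
One way to check (a sketch, assuming a CUDA build of PyTorch): torch.cuda.set_sync_debug_mode makes PyTorch warn or raise whenever an operation forces the host to synchronize with the GPU.

```python
import torch

torch.cuda.set_sync_debug_mode("warn")  # "error" raises instead of warning

x = torch.randn(1024, device="cuda")
mask = x > 0

y = x * mask   # stays asynchronous: no warning
z = x[mask]    # boolean indexing has a data-dependent output shape, so it syncs and warns
```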

@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
truncate_documents: bool = True
Collaborator

Not sure we need to move this, but if we do we need to add backward compatibility.

@@ -48,14 +48,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
truncate_documents: bool = Field(
default=True,
truncate_documents: bool | None = Field(
Collaborator

That works, but we normally do backward compatibility in _from_dict; see the example in lines 73-90 below. This one can go in GPTTrainerConfig._from_dict.
It also needs a TODO for removal.
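
Purely as an illustration of the migration pattern being suggested (the dict keys and helper name are assumptions, not the actual GPTTrainerConfig._from_dict code):

```python
import typing


def migrate_truncate_documents(config: dict[str, typing.Any]) -> dict[str, typing.Any]:
    """Move the legacy data.truncate_documents flag into the sampling parameters."""
    # TODO v0.x: remove this backward-compatibility shim.
    data = config.get("data", {})
    sampling = config.setdefault("sampling", {})
    if "truncate_documents" in data and "truncate_documents" not in sampling:
        sampling["truncate_documents"] = data.pop("truncate_documents")
    return config
```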
