fix loss masking and padding #287
base: main
Conversation
LGTM!
… soham/sft-fixes
fast_llm/data/dataset/gpt/sampled.py
Outdated
@@ -525,8 +529,8 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         elif "unshuffled_tokens" not in data:
             # Backward compatibility
             # TODO v0.x: Remove
-            assert self._truncate_documents
-            data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
+            assert not self._truncate_documents
That looks wrong; the old format only supported `_truncate_documents=True`.
fast_llm/data/dataset/gpt/sampled.py
Outdated
-            assert self._truncate_documents
-            data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"]
+            assert not self._truncate_documents
+            data["unshuffled_tokens"] = data["dataset"]["tokens_per_epoch"] * data["unshuffled_epochs"]
I believe the backward compatibility is from before we moved things to `dataset`, so it was right before.
Yes, I got a bit confused about the purpose of this. There's still an issue with `yaml_data` not containing `unshuffled_tokens` (we can't get it without building the padded cumsum). I pushed a hack to copy it from `loaded_yaml_data` instead to avoid breaking the flow; not sure if there's a cleaner way.
fast_llm/functional/cross_entropy.py
Outdated
@@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward(

         per_sample_loss = sum_exp_logits.log() - predicted_logits
         if loss_mask is not None:
-            per_sample_loss = per_sample_loss * loss_mask
+            per_sample_loss = per_sample_loss[loss_mask]
Why this change? `loss_mask` is an integer, so multiplication should work.
The loss is not as interpretable/comparable when we include the loss from masked tokens (which is 0) in the average. We start seeing a lot of variance in the reported loss when mixing samples with and without masked tokens.
Oh, so it's to have the right denominator in the mean? Indexing is a bad idea because it introduces a CUDA synchronization point (really slow), but you can divide by `loss_mask.sum()` instead. Also, we probably want to deal with the case `loss_mask.sum() == 0`.
@@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t
             input_ = split(input_, group=group, dim=0)
             if self._use_absolute_position_embeddings:
                 position_ids = split(position_ids, group=group, dim=0)
         embeddings = torch.embedding(self.word_embeddings_weight, input_)
+        # mask padded tokens
+        input_mask = input_ >= 0
I'd prefer not to do this unless padding is enabled because of the extra compute involved. Why do we have negative input anyway?
We set the padded tokens to -100 mainly to mask the loss on them. Many tokenizers don't have a pad token, so it's not straightforward to take the padding id from the config either.
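As a standalone sketch of the kind of masking being discussed (made-up shapes and tensor names, not the actual Fast-LLM code): map the negative padding ids to a valid index for the lookup, then zero out the resulting embeddings with the same mask.

```python
import torch

vocab_size, hidden_size = 128, 8
weight = torch.randn(vocab_size, hidden_size)
input_ = torch.tensor([[5, 7, 11, -100, -100]])  # -100 marks padded positions

# Clamp padded ids to a valid index (0) for the lookup, then zero their embeddings.
input_mask = input_ >= 0
embeddings = torch.embedding(weight, input_ * input_mask) * input_mask.unsqueeze(-1)
```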
Let me know if the change is OK now, or if we need to check for padding via a flag (maybe in kwargs).
We can add a flag, something like we already do for labels: https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/models/gpt/model.py#L332. Also, I just noticed the method is within `@torch.compile`, so the overhead shouldn't be too noticeable.
Something like this? (`truncate_documents` might also fit better in the batch config.)
… soham/sft-fixes
fast_llm/data/dataset/gpt/sampled.py
Outdated
@@ -467,6 +468,12 @@ def __getitem__(self, index: int) -> typing.Any:
                 else:
                     # Move on to the next sample.
                     token_count += padding_size
+            elif document_size + tokens_in_sample == self._parameters.sequence_length + 1:
+                if token_count + document_size == token_start:
+                    # Document belongs to the current sample but the condition below will include it for the next sample
I'm not following: why are we ignoring the document if it belongs to the current sample? (Also, it clearly belongs to the previous sample.)
From what I understand, in this scenario we'll have `token_start_index_in_document == token_end_index_in_document == document_size`, so we'll load 0 tokens from the sample. That seems unnecessary but not wrong, and it doesn't seem related to `document_size + tokens_in_sample == self._parameters.sequence_length + 1`? Seems to me the actual fix would be to replace `>=` with `>` in the condition below.
Oh yes, I got confused because I faced this issue in the multimodal branch, but it only occurs when there are images right after the text tokens. Will handle it there.
fast_llm/functional/cross_entropy.py
Outdated
@@ -146,8 +146,13 @@ def _fused_cross_entropy_forward_backward(
         per_sample_loss = sum_exp_logits.log() - predicted_logits
         if loss_mask is not None:
             per_sample_loss = per_sample_loss * loss_mask

         loss = per_sample_loss.mean()
+        unmasked_inputs = loss_mask.sum()
This still causes a CUDA sync. You can just do `loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1)`.
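A runnable sketch of that suggestion with made-up values (not the actual Fast-LLM code; `clamp(min=1)` is used here instead of `torch.maximum` purely to bound the denominator with a Python scalar):

```python
import torch

per_sample_loss = torch.rand(8)                      # made-up per-token losses
loss_mask = torch.tensor([1, 1, 0, 0, 1, 0, 1, 1])   # 1 = keep, 0 = masked

# Masked mean with tensor ops only: no boolean indexing (so no CUDA sync),
# and the denominator is clamped so an all-masked batch doesn't divide by zero.
loss = (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```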
For my own understanding, how can I check whether a PyTorch op causes a CUDA sync?
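As a side note (not answered in this thread, and assuming a CUDA build with a reasonably recent PyTorch), one way to surface synchronizing ops is PyTorch's sync debug mode:

```python
import torch

# Warn whenever an op forces the CPU to wait on the GPU; use "error" to raise instead.
torch.cuda.set_sync_debug_mode("warn")

x = torch.rand(1000, device="cuda")
mask = x > 0.5
y = x[mask]         # boolean indexing copies the mask's count to the host -> warning
total = mask.sum()  # stays on the GPU -> no warning
```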
@@ -75,6 +75,7 @@ class GPTSamplingParameters(SamplingParameters):
     use_loss_masking_spans: bool = False
     use_preference_loss_spans: bool = False
     cross_document_attention: bool = True
+    truncate_documents: bool = True
Not sure we need to move this, but if we do we need to add backward compatibility.
@@ -48,14 +48,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
-    truncate_documents: bool = Field(
-        default=True,
+    truncate_documents: bool | None = Field(
That works, but we normally do backward compatibility in `_from_dict`; see the example in lines 73-90 below. This one can go in `GPTTrainerConfig._from_dict`. Also needs a TODO for removal.
Description
Fixes and improvements for loss masking and padding
Closes #
Changes
List the key changes introduced in this PR:
- Fixes related to `truncate_documents=False`
- `sum(long_docs_filter)` was extremely slow for large datasets (see the sketch below)
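For illustration, the kind of speed-up the last point refers to, with a stand-in boolean tensor (not the actual Fast-LLM code or data):

```python
import torch

long_docs_filter = torch.rand(1_000_000) > 0.999  # stand-in boolean filter

# Python's builtin sum() iterates the tensor one element at a time,
# which is extremely slow for large tensors ...
slow_count = sum(long_docs_filter)

# ... while a single tensor-level reduction does the same count in one kernel.
fast_count = long_docs_filter.sum()
```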