Conversation

@lapp0
Collaborator

@lapp0 lapp0 commented Dec 30, 2024

As discussed in #3

This removes the issue of a smaller batch size resulting in a larger number of truncations. It should be used for validation / test. It's unclear whether it should be used for training: it's more sample-efficient, but 4% slower and has fewer tokens per batch.

@lapp0 lapp0 mentioned this pull request Dec 30, 2024
train_esm2.py Outdated
valid_loader = DistributedDataLoader(args.input_valid_bin, batch_size, ddp_rank, ddp_world_size)
test_loader = DistributedDataLoader(args.input_test_bin, batch_size, ddp_rank, ddp_world_size)
test_loader = DistributedDataLoader(args.input_test_bin, batch_size // 8, ddp_rank, ddp_world_size)
test_loader_padded = DistributedPaddedDataLoader(args.input_test_bin, batch_size // 8, ddp_rank, ddp_world_size, eos_id=2, pad_id=1)
Collaborator Author

@lapp0 lapp0 Dec 30, 2024

We shouldn't hardcode the pad / EOS IDs; also, test_loader should simply be replaced with test_loader_padded.
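
For illustration, a hedged sketch of deriving the IDs instead of hardcoding them, assuming a HuggingFace ESM-2 tokenizer were loaded in the script (the tokenizer here is hypothetical, not part of this PR):

from transformers import AutoTokenizer

# Hypothetical: only illustrates deriving the special-token IDs rather than hardcoding 2 and 1.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

test_loader = DistributedPaddedDataLoader(
    args.input_test_bin, batch_size // 8, ddp_rank, ddp_world_size,
    eos_id=tokenizer.eos_token_id,
    pad_id=tokenizer.pad_token_id,
)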

dataloading.py Outdated
buf = self.tokens[self.current_position:end_pos]
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item()
keep = max(keep or 0, self.batch_size - 2048)
Collaborator Author

The max sliding window size (2048) should be a parameter rather than hardcoded.

@lapp0 lapp0 marked this pull request as draft December 31, 2024 17:43
@lapp0
Collaborator Author

lapp0 commented Dec 31, 2024

This code has a few bugs. I think it makes more sense to simply update the data packing script so each .bin file ends with EOS, and begins with CLS.

Does that make sense to you?

Separately, could we open Issues / Discussions pages?

@lhallee

lhallee commented Jan 1, 2025

This code has a few bugs. I think it makes more sense to simply update the data packing script so each .bin file ends with EOS, and begins with CLS.

Does that make sense to you?

Yep, just pushed this. Processing the new packed dataset now.

Do you think we should add the padding just for eval, or for training too? It seems like in training we could still use keep to track which full sequences have been input, but just not add the pad tokens (more real tokens is better?). For example:

class DistributedDataLoaderTrain(DistributedDataLoader):
    def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, max_length=1024):
        super().__init__(filename_pattern, seq_len, process_rank, num_processes)
        self.eos_id = eos_id
        self.max_length = max_length

    def reset(self):
        self.current_shard = self.process_rank - self.num_processes
        self.advance()

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + self.num_processes) % len(self.files)
        self.current_position = 0
        self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])

    def next_batch(self):
        end_pos = self.current_position + self.batch_size
        buf = self.tokens[self.current_position:end_pos]
        input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
        keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item()
        keep = max(keep or 0, self.batch_size - self.max_length)
        # advance current position and load next shard if necessary
        self.current_position += keep
        if self.current_position + self.batch_size >= len(self.tokens):
            self.advance()
        return input_ids


class DistributedDataLoaderEval(DistributedDataLoader):
    def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, pad_id, max_length=1024):
        super().__init__(filename_pattern, seq_len, process_rank, num_processes)
        self.eos_id = eos_id
        self.pad_id = pad_id
        self.max_length = max_length

    def reset(self):
        self.current_shard = self.process_rank - self.num_processes
        self.advance()

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + self.num_processes) % len(self.files)
        self.current_position = 0
        self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])

    def next_batch(self):
        end_pos = self.current_position + self.batch_size
        buf = self.tokens[self.current_position:end_pos]
        input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
        keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item()
        keep = max(keep or 0, self.batch_size - self.max_length)
        input_ids[keep + 1:] = self.pad_id
        # advance current position and load next shard if necessary
        self.current_position += keep
        if self.current_position + self.batch_size >= len(self.tokens):
            self.advance()
        return input_ids
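
To make the keep computation above concrete, here is a small self-contained example of what the EOS-based cutoff evaluates to on a toy buffer (the token values are purely illustrative):

import torch

eos_id = 2
input_ids = torch.tensor([5, 7, eos_id, 9, 4, eos_id, 8, 3])

# (input_ids == eos_id)  -> [0, 0, 1, 0, 0, 1, 0, 0]
# .cumsum(dim=0)         -> [0, 0, 1, 1, 1, 2, 2, 2]
# .argmax()              -> 5, the first index where the count reaches its maximum,
#                           i.e. the position of the last EOS in the buffer
keep = (input_ids == eos_id).cumsum(dim=0).argmax().item()
print(keep)  # 5: everything after index 5 is an incomplete trailing sequence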

@lapp0
Collaborator Author

lapp0 commented Jan 2, 2025

I've seen better sample efficiency with padded training, although it was with a low sample size and needs further verification.

I'm thinking we should avoid padding within the packed dataset itself. Correct padding behavior is tied to batch size: if the runtime batch size is smaller than the batch size used when padding the packed dataset, we will truncate; if the runtime batch size is larger, we end up with padding inside a sample, which breaks rotary embedding behavior.

I'm pushing changes that construct the padding at runtime by overriding the advance() method. They also fix a bug where we train on the last sequence in the test set due to the overwritten input_ids variable.

You previously stated there were 2,077,660 validation tokens; however, when attempting to verify the integrity I see 2,073,745.

I see that Synthyra/omg_prot50_packed has been updated since. Can you check whether this number is still accurate? It seems to be a bug in my code; I'll check why.

train_esm2.py Outdated
test_steps += 1
test_loss += model(input_ids, sliding_window_size)

# TODO: do we need all-reduce?
Collaborator Author

This seems to be a bug: without an all_reduce, the test set will only show the first GPU's test_loss.
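
A minimal sketch of the kind of reduction that would address this, assuming the default torch.distributed process group is initialized (as it is under DDP) and test_loss is a scalar tensor on the local GPU; the helper name and averaging convention are assumptions, not the script's actual code:

import torch
import torch.distributed as dist

def reduce_test_loss(test_loss: torch.Tensor, test_steps: int) -> torch.Tensor:
    # sum the per-rank accumulated losses in place across all GPUs
    dist.all_reduce(test_loss, op=dist.ReduceOp.SUM)
    # then average over ranks and eval steps so every rank reports the same number
    return test_loss / (dist.get_world_size() * test_steps)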

dataloading.py Outdated
# the token right before the first pad has to be EOS, or there is no EOS at all
first_pad_idx = (seq == self.pad_id).nonzero(as_tuple=True)[0][0].item()
assert (seq == self.eos_id).sum() == 0 or seq[first_pad_idx-1] == self.eos_id

Collaborator Author

This can be removed; it just demonstrates the verification steps taken for loaded sequences.

valid_tokens = 0
with torch.no_grad():
    for _ in range(valid_steps):
        input_ids = valid_loader.next_batch()
Collaborator Author

overwriting input_ids results in training on the test set every val step.

We can just move input_ids = train_loader.next_batch() to right before we pass input_ids to the model during training, right?
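
A minimal sketch of that reordering; the loop skeleton and every name other than input_ids, train_loader, valid_loader, and sliding_window_size are assumptions about the surrounding script:

import torch

def train_step(model, train_loader, valid_loader, valid_steps, sliding_window_size, run_validation):
    if run_validation:
        valid_loss = 0.0
        with torch.no_grad():
            for _ in range(valid_steps):
                # consume the validation batch immediately; nothing from this loop is reused below
                valid_loss += model(valid_loader.next_batch(), sliding_window_size)
    # fetch the training batch right before the forward pass, so a stale or
    # overwritten input_ids can never leak into the training step
    input_ids = train_loader.next_batch()
    loss = model(input_ids, sliding_window_size)
    loss.backward()
    return loss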

@lapp0 lapp0 marked this pull request as ready for review January 2, 2025 19:28
@lhallee

lhallee commented Jan 2, 2025

@lapp0 the new shards are in the process of uploading; some error made them fail yesterday.

Overall the code looks good. I think I'll start implementing the metrics again now that we can use a reduced batch size and get accurate numbers.

@lhallee lhallee merged commit af0abd6 into master Jan 2, 2025
@lapp0
Collaborator Author

lapp0 commented Jan 2, 2025

Thanks!

A new packed dataset isn't necessary. This code assumes no padding and works with the current packed dataset.

A few more fixes:

  • I figured out the issue with the token counts referenced in my previous comment (I was missing the very first sample). I'll push a separate commit fixing this bug.
  • Weighting the loss based on non-padded token count. Although it's worth mentioning that for the training set the gradients aren't weighted by their non-pad token count.
  • Don't add 1 to self.local_batch_size; this is an artifact of modded-nanogpt, since its loss objective involves shifting by one token between input and target.

Some tokens are still missing; this is an artifact of modded-nanogpt's behavior: it ignores the last file_token_count % batch_size tokens. This should be fine.

Also worth mentioning the padding ratio for the val set:

2,077,660 tokens before padding, 2,162,688 after.
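
For reference, the padding fraction implied by those counts:

before, after = 2_077_660, 2_162_688
print(f"{after - before:,} padding tokens, {(after - before) / after:.1%} of the padded val set")
# 85,028 padding tokens, 3.9% of the padded val set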
