
Conversation

lhallee commented Dec 28, 2024

Currently, evaluating on the test set is much more memory intensive than training, and its loss is also inaccurate compared to the validation set.

It looks like this has something to do with moving the logits and labels around in CUDA memory: when calling .inference on the validation or test set, we need batch_size // 4 to avoid OOM.

I have also run experiments evaluating the validation set with .inference, and the loss convergence is much worse than with the regular forward pass. The only differences are that we mask at 15% instead of 20% and use batch_size // 4.

So it seems that either the input length strongly affects performance because of flex attention (which would be a major problem for the usability of the actual model), or training at a 20% mask rate and evaluating at 15% leads to much worse performance (which would also be unexpected compared to normal pLM experiments).

To try to figure this out, I am considering training an SDPA version and using a consistent masking rate.
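
For reference, here is roughly the kind of masking helper I mean, with the rate exposed as a parameter so train and eval can be matched. This is only an illustrative sketch; the function name, special-token handling, and -100 label convention are assumptions, not our actual pipeline.

```python
import torch

def mask_tokens(input_ids, mask_token_id, mask_rate=0.15, special_token_ids=()):
    """Illustrative MLM masking: corrupt mask_rate of positions, label only those."""
    labels = input_ids.clone()
    # Sample positions to corrupt, skipping special tokens.
    probs = torch.full(input_ids.shape, mask_rate, device=input_ids.device)
    for tok in special_token_ids:
        probs[input_ids == tok] = 0.0
    masked = torch.bernoulli(probs).bool()
    labels[~masked] = -100            # only masked positions contribute to the loss
    inputs = input_ids.clone()
    inputs[masked] = mask_token_id    # simple variant: always substitute the mask token
    return inputs, labels

# Using the same mask_rate for training and evaluation removes the rate
# discrepancy as a confounding variable.
```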

lhallee changed the title from "Investigating why the test set is so memory intensive and inaccurate" to "Investigating why the test set" on Dec 28, 2024
lhallee mentioned this pull request on Dec 28, 2024
lhallee changed the title from "Investigating why the test set" to "Investigating the test set" on Dec 28, 2024
lhallee (Author) commented Dec 29, 2024

The mask rate discrepancy does not seem to be the cause: training at 15% and evaluating at 15% still gives poor validation/test loss convergence when using .inference instead of forward, so I suspect the change in batch_size is responsible. Perhaps sending the logits to the CPU in .inference can fix this, but I am worried that the input length changes the performance so much.
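
Roughly what I have in mind, as a sketch only (this is not our actual .inference, and it assumes a Hugging Face-style output object with .loss and .logits):

```python
import torch

@torch.no_grad()
def evaluate(model, dataloader, device="cuda"):
    model.eval()
    all_logits, total_loss, n_batches = [], 0.0, 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        out = model(input_ids=input_ids, labels=labels)
        total_loss += out.loss.item()
        n_batches += 1
        # Move logits off the GPU right away so they never pile up in CUDA memory.
        all_logits.append(out.logits.float().cpu())
    # Return the mean loss and the per-batch logits kept on the CPU.
    return total_loss / max(n_batches, 1), all_logits
```

If the logits only live on the GPU for one batch at a time, the full training batch size might be usable again.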

lapp0 (Collaborator) commented Dec 30, 2024

I'm looking into this now, but the test loss for ESM2-150 (132M) seems sane to me. Is the loss still a problem, or is it just a performance issue now? If the loss is fixed, what was the fix?

lhallee (Author) commented Dec 30, 2024

The 132M run was fine; it processed the test set just like the validation set. But it would be good to return the logits too, so we can compute more metrics. Returning the logits takes up more memory, so we need a smaller batch size, and with the smaller batch size the metrics and loss are much worse, which is the confusing bit. The sequence length shouldn't affect performance that much with the document mask...
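
For context, this is roughly the kind of per-document mask I mean with flex attention: each query can only attend to keys from its own packed document, so the packed sequence length by itself shouldn't change what a token sees. Sketch only, not our exact implementation, and it needs a recent PyTorch with FlexAttention:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def document_block_mask(document_ids: torch.Tensor, seq_len: int):
    # document_ids: (seq_len,) tensor with the same id for every token of one
    # packed document, living on the device the mask is built for.
    def doc_mask(b, h, q_idx, kv_idx):
        return document_ids[q_idx] == document_ids[kv_idx]
    return create_block_mask(doc_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)

# Inside attention, with q, k, v shaped (B, H, S, D):
# out = flex_attention(q, k, v, block_mask=document_block_mask(doc_ids, q.shape[-2]))
```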

lapp0 (Collaborator) commented Dec 30, 2024

Perhaps the issue is truncation of documents hurting performance (1/4th the batch size implies roughly 4x as many sequences are truncated). During validation we could right-pad the sequences so that none are truncated. Let me see if this resolves the problem and prevents any batch-size-related issues.
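
Roughly what I mean, as a sketch (pad_token_id and the batch format are placeholders, not our actual loader):

```python
import torch

def right_pad_collate(sequences, pad_token_id):
    # Pad every sequence on the right to the length of the longest one,
    # so no document has to be truncated to fit the batch.
    max_len = max(seq.numel() for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, : seq.numel()] = seq
        attention_mask[i, : seq.numel()] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```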

lhallee (Author) commented Dec 30, 2024

I see. It may be worth changing the data loading a bit. I'm thinking about keeping the sequences (or tokens) separate and stacking them together up to the batch-size token budget, but never exceeding it, so nothing is ever truncated (see the sketch below).
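
Something like this, purely as a sketch of the packing idea (not the actual loader):

```python
def pack_by_token_budget(sequences, max_tokens_per_batch):
    # Greedily group whole sequences until the next one would exceed the token
    # budget; a sequence is never split or truncated, only carried to the next batch.
    batches, current, current_tokens = [], [], 0
    for seq in sequences:
        if current and current_tokens + len(seq) > max_tokens_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(seq)
        current_tokens += len(seq)
    if current:
        batches.append(current)
    return batches
```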

lapp0 mentioned this pull request on Dec 30, 2024
lapp0 (Collaborator) commented Dec 30, 2024

I trained a tiny model with DistributedPaddedDataLoader. The test dataloader batch size is 1/4th that of the validation loader. As expected, the padded test dataloader slightly outperforms the validation dataloader, and massively outperforms the unpadded test dataloader.

step:3901/4000 train_time:1521554ms step_avg:391.04ms
step:4000/4000 val_loss:2.4571 train_time:1560553ms step_avg:391.12ms perplexity:11.6714 param_count:4,953,618
model.safetensors: 100%|____________________________________________| 19.7M/19.7M [00:00<00:00, 26.2MB/s]
model.safetensors: 100%|____________________________________________| 19.7M/19.7M [00:00<00:00, 25.9MB/s]
model.safetensors: 100%|____________________________________________| 19.7M/19.7M [00:00<00:00, 20.3MB/s]
model.safetensors: 100%|____________________________________________| 19.7M/19.7M [00:00<00:00, 20.7MB/s]
peak memory consumption training: 7 GiB
No files have been modified since last commit. Skipping to prevent empty commit.
Test results | Loss: 2.4802 | Perplexity: 11.9438
Test (Padded) results | Loss: 2.4401 | Perplexity: 11.4744
Total train time (min): 26.01
Total train time (hours): 0.43
peak memory consumption testing: 7 GiB

#5

lhallee closed this on Jan 1, 2025
lhallee deleted the investigate_test_set branch on June 20, 2025
