
Conversation

lapp0 (Collaborator) commented Dec 23, 2024

Changes

  • Manually prefix documents with the BOS token ID, and use the BOS token to determine the document mask (see the sketch after this list)
  • Change tokens from uint8 -> uint16 to resolve a bug in packing data
  • Introduce MLM loss, masking a single token per document
    • Reimplement MLM loss per the ESM2 paper (12% masked, 1.5% replaced randomly, 1.5% kept)
  • Get `python data/cached_omgprot50.py 10` working
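
For the first bullet, here's a minimal sketch of how a BOS-delimited document mask can be built (illustrative PyTorch, not the exact implementation in this PR; the function name is mine):

```python
import torch

def document_mask(tokens: torch.Tensor, bos_id: int) -> torch.Tensor:
    """Boolean (seq, seq) mask that only allows attention within the same document.
    Documents are delimited by the BOS token prepended to each one."""
    doc_id = (tokens == bos_id).long().cumsum(dim=0)  # each token gets its document's index
    return doc_id[:, None] == doc_id[None, :]         # True where query and key share a document
```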

Baselines

  • Upper bound: a uniform random model over the 33-token vocabulary -> perplexity of 33
  • Lower bound: the Shannon entropy of protein sequences is estimated at ~2.5 bits per amino acid (CE loss of 1.73 nats, perplexity of 5.66)
  • This PR: achieves CE loss of 2.48, perplexity of 11.94 (see the quick check after this list)
    • trained on 4x4090 GPUs
    • 85M parameters, 6,000 steps @ batch size = 524,288 tokens
  • ESM2:
    [image: ESM2 baseline numbers]
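
A quick sanity check of the perplexity numbers above (plain Python, nothing repo-specific):

```python
import math

vocab_size = 33
print(math.exp(math.log(vocab_size)))  # uniform random model: perplexity 33.0
print(math.exp(2.5 * math.log(2)))     # 2.5 bits/residue -> 1.73 nats -> perplexity ~5.66
print(math.exp(2.48))                  # this PR's CE loss -> perplexity ~11.94
```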

lapp0 (Collaborator, Author) commented Dec 23, 2024

@lhallee where did you get these numbers and how is loss measured? Per the paper cited in the OP (in the "lower bound" bullet), it seems the theoretical lower bound for cross entropy loss is ~2.5 (perplexity of 5.66).

https://github.com/Synthyra/SpeedRunningESM2/blob/master/README.md?plain=1#L4-L11

lhallee commented Dec 23, 2024

Hi @lapp0. Really awesome work, thanks for your continued effort!

I got the loss from the script I wrote for the benchmark here.
I had it masking like BERT-base, which I believe is how ESM2 was trained. Where did you see 12% and 1.5%, etc.?

https://www.biorxiv.org/content/10.1101/2022.07.20.500902v2.full.pdf
[image: excerpt from the ESM2 paper describing the BERT-style masking procedure]

What is the motivation behind the BOS token? ESM has CLS and EOS tokens, which should suffice, right? Also, I'm curious about the motivation behind masking a single token per document. Since ESM2, the field has somewhat converged on 20% random masking with no replacement, which may work better.

lapp0 (Collaborator, Author) commented Dec 23, 2024

12% and 1.5% are just the resolved probabilities of what you quoted

"Each token has a 15% probability of inclusion. If included the tokens have an 80% probability of being replaced with a mask token, a 10% probability of being replaced with a random token"

Perhaps part of the issue is that I don't incorporate the 1.5% chance of "being replaced with an unmasked token". I'll try incorporating that.
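
For concreteness, here's a rough sketch of how those three cases can be drawn per token (illustrative PyTorch, not the exact code in this PR; the function and argument names are mine):

```python
import torch

def apply_mlm_mask(tokens, mask_id, vocab_size, special_mask):
    """ESM2/BERT-style corruption: 15% of non-special tokens become prediction targets;
    of those, 80% -> <mask> (12% overall), 10% -> random token (1.5%), 10% kept (1.5%)."""
    selected = (torch.rand(tokens.shape) < 0.15) & ~special_mask   # prediction targets
    roll = torch.rand(tokens.shape)
    corrupted = tokens.clone()
    corrupted[selected & (roll < 0.8)] = mask_id                   # 12% of all tokens
    random_case = selected & (roll >= 0.8) & (roll < 0.9)          # 1.5% of all tokens
    corrupted[random_case] = torch.randint(vocab_size, tokens.shape)[random_case]
    # the remaining selected tokens (1.5%) are left unchanged but still predicted
    return corrupted, selected   # compute the loss only on `selected` positions
```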

What is the motivation behind the BOS token? ESM has a CLS and EOS token which should suffice right?

CLS should suffice; I missed that the tokenizer was already producing it, since your dataloader passes the add_special_tokens=True argument.

Also, curious behind the single token being masked per document. Since ESM2 the field has somewhat converged on 20% random masking with no replacement which may work better.

This was based on reading about a different ESM variant; however, I've reverted the change to reflect the ESM2 paper (12%, 1.5%, 1.5%).

lhallee commented Dec 24, 2024

theoretical lower bound for cross entropy loss is ~2.5 (perplexity of 5.66).

With exp(loss) for perplexity, it looks like the lower bound is around 1.73, which makes more sense in the context of the benchmarked results and other training experiments. I see CE lower than 2.5 quite often, and sub-2 sometimes; I suppose it also depends on the mask percentage.

12% and 1.5% are just the resolved probabilities of what you quoted

That makes sense!

85M parameters, 1,480 steps @ batch size = 524,288

We may get more bang for our buck by maximizing steps instead of batch size. Not sure though.

lapp0 (Collaborator, Author) commented Dec 24, 2024

With exp^loss for perplexity looks like the lower bound is around 1.73

Checks out to me.
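
For anyone following along, the discrepancy was just the log base: the ~2.5 figure is bits per residue, while the training loss is in nats. A quick check (plain Python):

```python
import math

bits_per_residue = 2.5                    # estimated Shannon entropy of protein sequences
ce_nats = bits_per_residue * math.log(2)  # ~1.73 nats
print(ce_nats, math.exp(ce_nats))         # 1.7329..., perplexity ~5.66 (same as 2 ** 2.5)
```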

We may get more bang for our buck by maximizing steps instead of batch size. Not sure though.

This is the modded-nanogpt batch size. The Muon optimizer docstring suggests it has only been tested on very large batch sizes, although it's worth experimenting with smaller ones.

This PR should be good to try out. Let me know if the val_loss metric aligns with your validation set loss.

https://gist.github.com/lapp0/e076d696df147c7df8028cb2069300d4

https://huggingface.co/datasets/lapp0/omg_prot50_packed

lhallee commented Dec 24, 2024

Awesome, looking good.

That training run is looking much better; it's getting closer to a reasonable model after 6,000 steps. It's 524,288 tokens per batch, right? I suppose that's not as large as I originally thought (I had read it as 524,288 sequences).

Let me know if the val_loss metric aligns with your validation set loss.

So the benchmark is done on the test set of this dataset.

The validation set is also a fine benchmark, but the test set has additional sequences that were discovered after ESM2 and OMGprot50 were created. That being said, whether the val_loss aligns depends on how this was created:

```python
        hf_hub_download(repo_id="lapp0/omg_prot50_packed", filename=fname,
                        repo_type="dataset", local_dir=local_dir)
get("omgprot50_val_%06d.bin" % 0)
```

In my opinion, it would be good to check validation loss throughout training, and then test loss with the sequence reconstruction metrics for the final score. If you let me know how those .bin files were made here I can add the final test eval and metrics to the end.

Either way, I will also conduct the benchmark on the validation set so we have those numbers. That will be added to the readme soon.

lapp0 (Collaborator, Author) commented Dec 24, 2024

It's 524,288 tokens per batch right?

Correct.

If you let me know how those .bin files were made here I can add the final test eval and metrics to the end.

This is just an upload of the output from `data/omgprot50.py`.

lhallee commented Dec 24, 2024

Change tokens from uint8 -> uint16 to resolve a bug in packing data

What was this btw?

```python
val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
```

There are 2,077,660 validation tokens (without specials) and 3,659,599 test tokens.

split = "val" if shard_index == 0 else "train"

It looks like the first shard of the training set is being kept as the validation split here.

I can take a look at this file tomorrow. It looks like what's currently on the docket is:

  • make sure the packed dataset has correctly generated evaluation splits
  • validate on the validation set during training, and test on the test set at the end
  • add a return_logits argument so that we can calculate metrics when necessary (see the sketch below)
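
For that last bullet, something along these lines might work (a hedged sketch; `evaluate` and its arguments are illustrative names, not the repo's API):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def evaluate(model, corrupted, targets, selected, return_logits=False):
    """Cross-entropy over masked positions only; optionally hand back the logits
    so downstream metrics (e.g. sequence reconstruction) can be computed."""
    logits = model(corrupted)                     # (batch, seq, vocab) output assumed
    loss = F.cross_entropy(logits[selected], targets[selected])
    return (loss, logits) if return_logits else loss
```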

If you get to this before me, obviously feel free.

Going to merge for now. Thanks again for all of your help, and happy holidays if you are celebrating anything :)

lhallee merged commit 4520516 into master on Dec 24, 2024
lhallee commented Dec 27, 2024

Hey @lapp0!

Have some really promising results after 20,000 steps (~9 hours on 1 H200): we're beating ESM2-150M with 141M parameters!

This is already a massive speed-up, considering I think this run is about (or less than) 1e19 FLOPs, while the original ESM2-150M was between 1e20 and 1e21 FLOPs.

However, this was on the validation set. My custom test-set evaluation seems to be bugged, as its loss is much higher. It could be an issue with the sliding window (which I've never messed with before), with my PyTorch Dataset, or possibly with the batching. I can't seem to find the issue, and would appreciate you taking a look if you have the time.

lapp0 (Collaborator, Author) commented Dec 27, 2024

Happy holidays to you as well!

Change tokens from uint8 -> uint16 to resolve a bug in packing data

Changing the token type from uint16 to uint8 would halve the size of the cached inputs without affecting functionality. The bug with uint8 is that the .bin files’ headers report 100 million tokens, but in practice, there are usually a few hundred fewer. This discrepancy might be a simple oversight or something more involved, although I haven’t investigated it thoroughly.
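
If it helps with debugging, here's a rough sketch of a shard writer that records the actual token count in the header (the 256-int32 header layout with a magic number and version is an assumption carried over from modded-nanogpt-style data scripts, not necessarily what this repo uses):

```python
import numpy as np

def write_shard(path, tokens):
    # Record the *actual* number of tokens so readers don't assume a full 100M shard.
    tokens = np.asarray(tokens, dtype=np.uint8)   # the 33-token vocabulary fits in uint8
    header = np.zeros(256, dtype=np.int32)
    header[0], header[1], header[2] = 20240520, 1, len(tokens)  # magic, version, token count
    with open(path, "wb") as f:
        f.write(header.tobytes())
        f.write(tokens.tobytes())
```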

However, this was on the validation set. My custom test set evaluation seems to be bugged, as its loss is much higher. Could be an issue with the sliding window (which I've never messed with before), with my PyTorch Dataset, or possibly with the batching.

Exciting to see potentially high quality results! I'll take a look at the discrepancy this weekend.

lhallee commented Dec 27, 2024

Change tokens from uint8 -> uint16 to resolve a bug in packing data

I was actually able to fix this, so no worries. Now every 100,000,000 tokens takes 100 MB, which is great.

Exciting to see potentially high quality results! I'll take a look at the discrepancy this weekend.

Awesome. I'll let you know if I find anything.

lhallee commented Dec 28, 2024

I've opened this up so we can figure out the test set.

lhallee deleted the working-baseline branch on June 20, 2025