Working baseline #2

Conversation
@lhallee where did you get these numbers and how is loss measured? Per the paper cited in the OP (in the "lower bound" bullet), it seems the theoretical lower bound for cross entropy loss is ~2.5 (perplexity of 5.66). https://github.com/Synthyra/SpeedRunningESM2/blob/master/README.md?plain=1#L4-L11
Hi @lapp0. Really awesome work, thanks for your continued effort! I got the loss from the script I wrote for the benchmark here. https://www.biorxiv.org/content/10.1101/2022.07.20.500902v2.full.pdf What is the motivation behind the BOS token? ESM has a CLS and EOS token, which should suffice, right? Also, curious about the single token being masked per document. Since ESM2, the field has somewhat converged on 20% random masking with no replacement, which may work better.
12% and 1.5% are just the resolved probabilities of what you quoted: "Each token has a 15% probability of inclusion. If included the tokens have an 80% probability of being replaced with a mask token, a 10% probability of being replaced with a random token" (i.e. 15% × 80% = 12% masked and 15% × 10% = 1.5% randomized). Perhaps part of the issue is that I don't incorporate the remaining 1.5% chance of the token "being replaced with an unmasked token" (i.e. left unchanged). Will try incorporating that.
CLS should suffice, I missed that the tokenizer was producing this since your dataloader incorporates the
This was based on reading about a different ESM variant; however, I've reverted the change to reflect the ESM2 paper (12%, 1.5%, 1.5%).
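For reference, a minimal sketch of how the 12% / 1.5% / 1.5% split falls out of the 15% + 80/10/10 scheme. This is illustrative only, not the repo's actual collator; the `mask_id`, `vocab_size`, and `special_mask` arguments are assumptions.

```python
import torch

def esm2_style_mask(input_ids, mask_id, vocab_size, special_mask):
    """Illustrative ESM2/BERT-style corruption: each non-special token is
    selected with p=0.15; selected tokens are 80% replaced by <mask>, 10%
    replaced by a random token, 10% left unchanged (12% / 1.5% / 1.5% overall)."""
    labels = input_ids.clone()
    selected = (torch.rand(input_ids.shape) < 0.15) & ~special_mask
    labels[~selected] = -100  # compute loss only on selected positions

    roll = torch.rand(input_ids.shape)
    masked = selected & (roll < 0.8)                       # 15% * 80% = 12%
    randomized = selected & (roll >= 0.8) & (roll < 0.9)   # 15% * 10% = 1.5%
    # remaining selected positions (1.5% overall) keep their original token

    corrupted = input_ids.clone()
    corrupted[masked] = mask_id
    corrupted[randomized] = torch.randint(vocab_size, input_ids.shape)[randomized]
    return corrupted, labels
```

The "left unchanged" bucket still contributes to the loss through `labels`, which is why dropping it changes the effective objective slightly.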
With exp(loss) for perplexity, it looks like the lower bound on cross-entropy is around 1.73 (since exp(1.73) ≈ 5.66), which makes more sense in the context of the benchmarked results and other training experiments. We're seeing CE lower than 2.5 quite often and sub-2 sometimes; I suppose it also depends on the mask percentage.
That makes sense!
We may get more bang for our buck by maximizing steps instead of batch size. Not sure though.
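As a quick unit sanity check on the 2.5-vs-1.73 discrepancy above (assuming the ~5.66 perplexity floor quoted from the paper; reading the 2.5 figure as bits is a guess):

```python
import math

# Perplexity equals exp(CE) when the loss is in nats, and 2**CE when it is in bits.
# A perplexity floor of ~5.66 therefore maps to ~1.73 nats, while ~2.5 matches the
# same floor expressed in bits (2**2.5 ≈ 5.66).
print(math.log(5.66))   # ≈ 1.733  (nats)
print(math.log2(5.66))  # ≈ 2.501  (bits)
print(math.exp(1.733))  # ≈ 5.66
```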
Checks out to me.
This is the modded-nanogpt batch size. The Muon optimizer docstring suggests it is only tested on very large batch sizes, although it's worth experimenting with smaller sizes. This PR should be good to try out. Let me know if the https://gist.github.com/lapp0/e076d696df147c7df8028cb2069300d4
Awesome, looking good. That training run is looking much better, getting closer to a reasonable model after 6000 steps. It's 524,288 tokens per batch, right? I suppose that's not as large as I originally thought (I'd read it as 524,288 sequences).
So the benchmark is done on the test set of this dataset. The validation set is also a fine benchmark, but the test set has additional sequences that were discovered after ESM2 or OMGprot50 were created. That being said, it depends on how this was created: `hf_hub_download(repo_id="lapp0/omg_prot50_packed", filename=fname, repo_type="dataset", local_dir=local_dir)` and `get("omgprot50_val_%06d.bin" % 0)`. In my opinion, it would be good to check validation loss throughout training and then test loss with the sequence reconstruction metrics for the final score. If you let me know how those .bin files were made, I can add the final test eval and metrics to the end. Either way, I will also conduct the benchmark on the validation set so we have those numbers. That will be added to the README soon.
Correct.
This is just an upload of the output from
What was this btw?
There are 2,077,660 validation tokens (without specials) and 3,659,599 test tokens.
It looks like the first shard of the training set is being kept as the validation split here. Can take a look at this file tomorrow. Looks like what's currently on the docket is
If you get to this before me, obviously feel free. Going to merge for now. Thanks again for all of your help, and happy holidays if you are celebrating anything :)
Hey @lapp0! Have some really promising results after 20,000 steps (~9 hours on 1 H200): beating ESM2-150M at 141M parameters! This is already a massive speedup, considering I think this is about (or less than) 1e19 FLOPs and the original ESM2-150M was between 1e20 and 1e21 FLOPs. However, this was on the validation set. My custom test set evaluation seems to be bugged, as its loss is much higher. It could be an issue with the sliding window (which I've never messed with before), with my PyTorch Dataset, or possibly with the batching. I can't seem to find the issue; would appreciate you taking a look if you have the time.
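Not the eval code in question, just a minimal sketch of the kind of token-weighted cross-entropy check that can isolate whether the gap comes from the eval-side Dataset/masking/batching rather than the model. The model interface and the (input_ids, labels) batch format are assumptions.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def eval_masked_ce(model, dataloader, device="cuda"):
    # Average masked-LM cross-entropy over a held-out split, counting only
    # positions whose label is not the ignore index (-100). If this disagrees
    # with the training-time validation loss, the problem usually lives in the
    # eval-side masking or batching rather than in the model itself.
    model.eval()
    total_loss, total_tokens = 0.0, 0
    for input_ids, labels in dataloader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits = model(input_ids)  # assumes the model returns raw logits
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                               ignore_index=-100, reduction="sum")
        total_loss += loss.item()
        total_tokens += (labels != -100).sum().item()
    return total_loss / max(total_tokens, 1)
```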
Happy holidays to you as well!
Changing the token type from
Exciting to see potentially high quality results! I'll take a look at the discrepancy this weekend.
I was actually able to fix this, so no worries. Now each 100,000,000 tokens is 100 MB, which is great.
Awesome. I'll let you know if I find anything.
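For context on the 100 MB figure above, a quick illustrative calculation; the uint8 choice and the filename are assumptions implied by the quoted size, not the repo's actual writer.

```python
import numpy as np

# Storing token ids as uint8 (1 byte each) is enough for ESM2's ~33-token
# vocabulary and makes 100,000,000 tokens come out to ~100 MB on disk,
# versus ~200 MB if they were kept as uint16.
tokens = np.random.randint(0, 33, size=100_000_000).astype(np.uint8)
print(tokens.nbytes / 1e6, "MB")    # 100.0
tokens.tofile("example_shard.bin")  # ~100 MB on disk
```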
Opened this up so we can figure out the test set.

Changes

`python data/cached_omgprot50.py 10` (branch: workingBaselines)