14 changes: 13 additions & 1 deletion README.md
@@ -1,5 +1,17 @@
# Replicating ESM2 at the speed of sound
This repo is an open-source collaboration to reproduce the ESM2-150M validation loss in as little time as possible, inspired by the fantastic [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt) repo.

## Quick Start

Set up the environment and train ESM2:

```
git clone https://github.com/Synthyra/SpeedRunningESM2 && cd SpeedRunningESM2
pip install -r requirements.txt
pip install --pre torch==2.6.0.dev20241203+cu124 --index-url https://download.pytorch.org/whl/nightly/cu124 --upgrade # install torch 2.6.0
python data/cached_omgprot50.py 10 # downloads only the first 1.0B training tokens to save time
./run.sh
```

## Benchmarks to beat
[OMGprot50](https://huggingface.co/datasets/Synthyra/omg_prot50) test set, scored with a masked language modeling (MLM) objective at 15% masking.
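The repo's actual evaluation lives in the training script rather than in this diff. Purely as a sketch of what a 15%-masking MLM evaluation looks like, the snippet below masks 15% of the residues in a toy sequence and scores cross-entropy only at the masked positions, using the public `facebook/esm2_t30_150M_UR50D` checkpoint from `transformers`. The checkpoint choice, the toy sequence, and the all-`<mask>` corruption (no 80/10/10 split) are illustrative assumptions, not the repo's benchmark code.

```
import torch
from transformers import EsmForMaskedLM, EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t30_150M_UR50D").eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # toy protein sequence, not from the test set
enc = tokenizer(sequence, return_tensors="pt")
input_ids = enc["input_ids"].clone()
labels = input_ids.clone()

# pick 15% of the non-special positions and replace them with the mask token
special_ids = set(tokenizer.all_special_ids)
candidates = torch.tensor([i for i, t in enumerate(input_ids[0].tolist()) if t not in special_ids])
n_mask = max(1, int(0.15 * candidates.numel()))
masked_positions = candidates[torch.randperm(candidates.numel())[:n_mask]]
input_ids[0, masked_positions] = tokenizer.mask_token_id

# compute loss only on the masked positions (-100 labels are ignored)
keep = torch.zeros_like(labels, dtype=torch.bool)
keep[0, masked_positions] = True
labels[~keep] = -100

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=enc["attention_mask"], labels=labels)
print(f"masked-token cross-entropy: {out.loss.item():.3f}")
```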
22 changes: 7 additions & 15 deletions data/cached_omgprot50.py
@@ -1,24 +1,16 @@
 import os
 import sys
 from huggingface_hub import hf_hub_download


-# Download the omgprot50 tokens from huggingface. This
-# saves about an hour of startup time compared to regenerating them.
 def get(fname):
     local_dir = os.path.join(os.path.dirname(__file__), 'omgprot50')
     if not os.path.exists(os.path.join(local_dir, fname)):
-        hf_hub_download(repo_id="Synthyra/omg_prot50", filename=fname, repo_type="dataset", local_dir=local_dir)
+        hf_hub_download(repo_id="lapp0/omg_prot50_packed", filename=fname,
+                        repo_type="dataset", local_dir=local_dir)


-get("data/valid-00000-of-00001.parquet")
-get("data/test-00000-of-00001.parquet")
-# Full omgprot50, which is roughly 52 billion tokens
-# Each chunk is ~2.3 million sequences, ~600,000,000 tokens, 490 MB
-num_chunks = 91
+get("omgprot50_val_%06d.bin" % 0)
+num_chunks = 442  # Each chunk is 100M tokens
 if len(sys.argv) >= 2:  # we can pass an argument to download less
     num_chunks = int(sys.argv[1])


 for i in range(1, num_chunks+1):
-    get(f"data/train-{i:05d}-of-00091.parquet")
+    get("omgprot50_train_%06d.bin" % i)
13 changes: 7 additions & 6 deletions data/omgprot50.py
@@ -22,7 +22,7 @@


 def write_datafile(filename, toks):
     """
     Saves token data as a .bin file, for reading in C.
     - First comes a header with 256 int32s
     - The tokens follow, each as a uint16
@@ -61,20 +61,21 @@ def write_datafile(filename, toks):
 enc = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

 def tokenize(doc):
     # tokenizes a single document and returns a numpy array of uint16 tokens
-    tokens = enc.encode(doc["sequence"], add_special_tokens=True).input_ids
+    tokens = enc.encode(doc["sequence"], add_special_tokens=True)
+    assert tokens[0] == 0 and tokens.count(0) == 1, "CLS token should always be at start and only at start"
     tokens_np = np.array(tokens)
     assert (0 <= tokens_np).all() and (tokens_np < 2**8).all(), "token dictionary too large for uint8"
-    tokens_np_uint8 = tokens_np.astype(np.uint8)  # can use uint8 because only 33 tokens
-    return tokens_np_uint8
+    tokens_np_uint16 = tokens_np.astype(np.uint16)  # uint16 is plenty: the vocabulary has only 33 tokens
+    return tokens_np_uint16


 # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
 nprocs = max(1, os.cpu_count() - 2)  # don't hog the entire system
 with mp.Pool(nprocs) as pool:
     shard_index = 0
     # preallocate buffer to hold current shard
-    all_tokens_np = np.empty((args.shard_size,), dtype=np.uint8)
+    all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
     token_count = 0
     progress_bar = None
     for tokens in pool.imap(tokenize, fw, chunksize=16):
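The docstring above pins down the shard layout: a 256-int32 header followed by the tokens, now written as uint16. Below is a hedged reader sketch for that format; the meaning of individual header fields is not shown in this diff, so the sketch simply skips past the 1024-byte header, and the shard path is the one produced by the downloader above.

```
import numpy as np

def read_shard(path):
    # per write_datafile's docstring: 256 int32s of header, then uint16 tokens
    with open(path, "rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    return header, tokens

header, tokens = read_shard("data/omgprot50/omgprot50_val_000000.bin")
print(f"{tokens.size:,} tokens; first ids: {tokens[:8].tolist()}")
```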
2 changes: 1 addition & 1 deletion run.sh
@@ -1 +1 @@
-torchrun --standalone --nproc_per_node=8 train_gpt2.py
+torchrun --standalone --nproc_per_node=8 train_esm2.py
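`run.sh` launches the renamed training entry point with `torchrun` across 8 local processes; on a machine with fewer GPUs, one would presumably lower `--nproc_per_node` to match the available device count.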