data/download_omgprot50.py: 2 changes (1 addition, 1 deletion)
@@ -17,5 +17,5 @@ def get(fname):
args = parser.parse_args()
get("omgprot50_valid_%06d.bin" % 0)
get("omgprot50_test_%06d.bin" % 0)
for i in range(1, args.num_chunks+1):
for i in range(0, args.num_chunks+1):
get("omgprot50_train_%06d.bin" % i)
dataloading.py: 160 changes (78 additions, 82 deletions)
@@ -2,115 +2,111 @@
from pathlib import Path


def _peek_data_shard(file: Path):
def _load_data_shard(file: Path):
# only reads the header, returns header data
# header is 256 int32
header = torch.from_file(f"{file}", False, 256, dtype=torch.int32)
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
return int(header[2]) # number of tokens (claimed)


def _load_data_shard(path: Path, num_tokens):
with path.open("rb", buffering=0) as f:
assert header[0] == 20240520, 'magic number mismatch in the data .bin file'
assert header[1] == 1, 'unsupported version'
num_tokens = int(header[2]) # number of tokens (claimed)
with file.open('rb', buffering=0) as f:
tokens = torch.empty(num_tokens, dtype=torch.uint8, pin_memory=True)
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy())
assert nbytes == num_tokens, "number of tokens read does not match header?"
assert nbytes == num_tokens, 'number of tokens read does not match header?'
return tokens


class DistributedDataLoader:
def __init__(self, filename_pattern, batch_size, process_rank, num_processes):
self.process_rank = process_rank
self.num_processes = num_processes
self.batch_size = batch_size

# glob files that match the pattern
def __init__(self, filename_pattern: str, batch_size: int, rank: int, world_size: int):
assert batch_size % world_size == 0
self.world_size = world_size
self.rank = rank
self.files = sorted(Path.cwd().glob(filename_pattern))
assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

# load and validate all data shards, count number of tokens in total
self.files_num_tokens = [_peek_data_shard(file) for file in self.files]
self.total_num_tokens = sum(self.files_num_tokens)
self.batch_size = batch_size
self.local_batch_size = self.batch_size // self.world_size

self.reset()

def reset(self):
self.current_shard = -1
self.next_shard = 0
self.advance()

def advance(self): # advance to next data shard
self.current_shard = (self.current_shard + 1) % len(self.files)
self.current_position = self.process_rank * self.batch_size
self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])
self.pos = 0
self.tokens = _load_data_shard(self.files[self.next_shard])
self.next_shard = (self.next_shard + 1) % len(self.files)

def next_batch(self):
batch_size = self.batch_size * self.num_processes
buf = self.tokens[self.current_position:self.current_position+self.batch_size]
# host side async is sufficient;
buf = self.tokens[self.pos + self.rank * self.local_batch_size:][:self.local_batch_size + 1]
# by @YouJiacheng: host side async is sufficient;
# no performance improvement was observed when introducing a separate stream.
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs
# advance current position and load next shard if necessary
self.current_position += batch_size
if self.current_position + batch_size >= len(self.tokens):
self.advance()
return input_ids


class DistributedDataLoaderTrain(DistributedDataLoader):
def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, max_length=1024):
super().__init__(filename_pattern, seq_len, process_rank, num_processes)
self.eos_id = eos_id
self.max_length = max_length

def reset(self):
self.current_shard = self.process_rank - self.num_processes
self.advance()

def advance(self): # advance to next data shard
self.current_shard = (self.current_shard + self.num_processes) % len(self.files)
self.current_position = 0
self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])

def next_batch(self):
end_pos = self.current_position + self.batch_size
buf = self.tokens[self.current_position:end_pos]
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item()
keep = max(keep or 0, self.batch_size - self.max_length)
sequence = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs
# advance current position and load next shard if necessary
self.current_position += keep
if self.current_position + self.batch_size >= len(self.tokens):
self.pos += self.batch_size
if self.pos + self.batch_size + 1 >= len(self.tokens):
self.advance()
return input_ids
return sequence
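
As an aside on the "host side async is sufficient" comment by @YouJiacheng above: the transfer pattern it refers to is a pinned host tensor copied to the GPU with non_blocking=True, with no dedicated copy stream. A minimal standalone sketch of that pattern (buffer names and sizes are made up):

# --- illustrative sketch (not part of the diff) ---
import torch

host_buf = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)  # pinned host memory
# ... fill host_buf from disk, e.g. f.readinto(host_buf.numpy()) ...
device_buf = host_buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
# With pinned memory the copy is asynchronous with respect to the host, so later
# host-side work overlaps with it; per the comment above, adding a separate
# torch.cuda.Stream() for the copy gave no measurable speedup in this workload.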


class DistributedDataLoaderEval(DistributedDataLoader):
def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, pad_id, max_length=1024):
super().__init__(filename_pattern, seq_len, process_rank, num_processes)
class DistributedPaddedDataLoader(DistributedDataLoader):
def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_id, pad_id, max_epochs=1):
self.eos_id = eos_id
self.pad_id = pad_id
self.max_length = max_length

def reset(self):
self.current_shard = self.process_rank - self.num_processes
self.advance()
self._leftover_tokens = torch.empty(0, dtype=torch.uint8)
self.max_epochs = max_epochs
super().__init__(filename_pattern, seq_len, process_rank, num_processes)

def advance(self): # advance to next data shard
self.current_shard = (self.current_shard + self.num_processes) % len(self.files)
self.current_position = 0
self.tokens = _load_data_shard(self.files[self.current_shard], self.files_num_tokens[self.current_shard])
def advance(self):
self.pos = 0

if self.next_shard // len(self.files) >= self.max_epochs:
raw_tokens = self._leftover_tokens
else:
self.next_shard += 1
raw_tokens = _load_data_shard(self.files[self.next_shard % len(self.files)])
raw_tokens = torch.cat([self._leftover_tokens, raw_tokens], dim=0)

if not raw_tokens.numel():
self._leftover_tokens = torch.empty(0, dtype=torch.uint8)
self.tokens = None
return

processed_chunks = []
curr_batch_len = 0

eos_positions = (raw_tokens == self.eos_id).nonzero(as_tuple=True)[0]
for i in range(len(eos_positions)-1):
sample_end = eos_positions[i+1]
sample = raw_tokens[eos_positions[i]+1:sample_end+1] # One sample: "CLS ... EOS"

assert sample[0] == 0 and sample[-1] == 2, (sample[0], sample[-1])
assert curr_batch_len < self.local_batch_size, curr_batch_len

# if adding sample exceeds the batch size resulting in truncation, pad to end of batch, starting a fresh batch
if len(sample) + curr_batch_len >= self.local_batch_size:
num_pad = self.local_batch_size - curr_batch_len
processed_chunks.append(torch.full((num_pad,), self.pad_id))
curr_batch_len = 0

# if len(sample) > local batch size, chunk evenly, making multiple padded batches, starting a fresh batch
if len(sample) > self.local_batch_size:
for split_sample in torch.chunk(sample, len(sample) // self.local_batch_size + 1):
processed_chunks.append(split_sample)
num_pad = self.local_batch_size - len(split_sample)
processed_chunks.append(torch.full((num_pad,), self.pad_id))
curr_batch_len = 0
continue

processed_chunks.append(sample)
curr_batch_len += len(sample)

self._leftover_tokens = raw_tokens[sample_end+1:]
self.tokens = torch.cat(processed_chunks, dim=0)

def next_batch(self):
end_pos = self.current_position + self.batch_size
buf = self.tokens[self.current_position:end_pos]
input_ids = buf.to(device="cuda", dtype=torch.int32, non_blocking=True)
keep = (input_ids == self.eos_id).cumsum(dim=0).argmax().item()
keep = max(keep or 0, self.batch_size - self.max_length)
input_ids[keep + 1:] = self.pad_id
# advance current position and load next shard if necessary
self.current_position += keep
if self.current_position + self.batch_size >= len(self.tokens):
self.advance()
return input_ids
if self.tokens is None:
return None

seq = super().next_batch()
return seq
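
To make the batch-packing logic in DistributedPaddedDataLoader.advance easier to follow, here is a small self-contained sketch of the same idea: split the token stream at EOS boundaries, pack whole samples into fixed-length batches, and pad the remainder. The token ids are toy values, with cls=0, pad=1, eos=2 assumed as in the ESM vocabulary; pack_with_padding is a hypothetical helper, not part of the PR.

# --- illustrative sketch (not part of the diff) ---
import torch

def pack_with_padding(raw, batch_len, eos_id=2, pad_id=1):
    # pack EOS-terminated samples into contiguous batches of length batch_len
    chunks, used = [], 0
    eos_pos = (raw == eos_id).nonzero(as_tuple=True)[0]
    for i in range(len(eos_pos) - 1):
        sample = raw[eos_pos[i] + 1 : eos_pos[i + 1] + 1]   # one sample: "CLS ... EOS"
        if len(sample) + used >= batch_len:                  # sample would spill over,
            chunks.append(torch.full((batch_len - used,), pad_id))  # so pad out this batch
            used = 0
        if len(sample) > batch_len:                          # oversized sample: split it
            for piece in torch.chunk(sample, len(sample) // batch_len + 1):
                chunks.append(piece)
                chunks.append(torch.full((batch_len - len(piece),), pad_id))
            used = 0
            continue
        chunks.append(sample)
        used += len(sample)
    return torch.cat(chunks)

# three samples (lengths 4, 3, 3) packed into batches of 6, padded with pad_id=1:
stream = torch.tensor([2, 0, 5, 6, 2, 0, 7, 2, 0, 9, 2])
print(pack_with_padding(stream, batch_len=6))
# tensor([0, 5, 6, 2, 1, 1, 0, 7, 2, 1, 1, 1, 0, 9, 2])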
model.py: 8 changes (4 additions, 4 deletions)
@@ -138,10 +138,10 @@ def forward(self, x: torch.Tensor, vi: torch.Tensor, x0: torch.Tensor, block_mas


class ValueEmbedding(nn.Module):
def __init__(self, config: "ModelConfig"):
def __init__(self, config: "ModelConfig", padding_idx):
super().__init__()
self.embed = nn.ModuleList([
nn.Embedding(config.vocab_size, config.hidden_size)
nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=padding_idx)
for _ in range(config.num_hidden_layers // 2)
])

@@ -169,11 +169,11 @@ def __init__(self, config: ModelConfig):
# Add learnable skip connection weights for decoder layers
self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))

self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
self.embed = nn.Embedding(self.vocab_size, config.hidden_size, padding_idx=tokenizer.pad_token_id)
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
# token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
# U-net structure on token value embeddings by @leloykun
self.value_embeds = ValueEmbedding(config)
self.value_embeds = ValueEmbedding(config, padding_idx=tokenizer.pad_token_id)
self.lm_head = CastedLinear(config.hidden_size, self.vocab_size)
self.lm_head.weight.data.zero_() # @Grad62304977
self.cross_entropy = nn.CrossEntropyLoss()
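
For context on the padding_idx change above: nn.Embedding with padding_idx keeps that row of the weight matrix at zero and never accumulates gradient into it, so padded positions contribute nothing through the token and value embeddings. A minimal sketch, assuming a pad id of 1 as in the ESM tokenizer:

# --- illustrative sketch (not part of the diff) ---
import torch
import torch.nn as nn

pad_id = 1                                 # assumed ESM pad_token_id
emb = nn.Embedding(8, 4, padding_idx=pad_id)
print(emb.weight[pad_id])                  # all zeros at initialization
out = emb(torch.tensor([3, pad_id, 5]))    # the padded position embeds to zeros
out.sum().backward()
print(emb.weight.grad[pad_id])             # zeros: no gradient flows into the pad row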
train_esm2.py: 52 changes (29 additions, 23 deletions)
@@ -27,12 +27,13 @@
import torch
import torch.distributed as dist
import torch._inductor.config as config
from transformers import EsmTokenizer
from torch.nn.parallel import DistributedDataParallel as DDP
from pathlib import Path

from optimizer import Muon
from model import ModelConfig, ESM, CastedLinear
from dataloading import DistributedDataLoader
from dataloading import DistributedPaddedDataLoader


def get_args():
@@ -159,19 +160,16 @@ def print0(s, logonly=False):
print0(f'Total batch size: {args.batch_size} tokens')

# load tokens
train_loader = DistributedDataLoader(args.input_bin, batch_size, ddp_rank, ddp_world_size)
valid_loader = DistributedDataLoader(args.input_valid_bin, batch_size, ddp_rank, ddp_world_size)
test_loader = DistributedDataLoader(args.input_test_bin, batch_size, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.total_num_tokens} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {valid_loader.total_num_tokens} across {len(valid_loader.files)} files")
print0(f"Testing DataLoader: total number of tokens: {test_loader.total_num_tokens} across {len(test_loader.files)} files")
tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
eos_id, pad_id = tokenizer.eos_token_id, tokenizer.pad_token_id
train_loader = DistributedPaddedDataLoader(args.input_bin, batch_size, ddp_rank, ddp_world_size, eos_id=eos_id, pad_id=pad_id)
valid_loader = DistributedPaddedDataLoader(args.input_valid_bin, batch_size, ddp_rank, ddp_world_size, eos_id=eos_id, pad_id=pad_id)
test_loader = DistributedPaddedDataLoader(args.input_test_bin, batch_size // 8, ddp_rank, ddp_world_size, eos_id=eos_id, pad_id=pad_id)
print0(f"Training DataLoader: {len(train_loader.files)} files")
print0(f"Validation DataLoader: {len(valid_loader.files)} files")
print0(f"Testing DataLoader: {len(test_loader.files)} files")
print0('='*100, logonly=True)

valid_steps = valid_loader.total_num_tokens // args.batch_size
test_steps = test_loader.total_num_tokens // args.batch_size

input_ids = train_loader.next_batch()

model = ESM(model_config)
model = model.cuda().bfloat16()
for m in model.modules():
@@ -250,16 +248,18 @@ def get_lr(it):
# run validation batches
model.eval()
valid_loader.reset()
val_loss = 0.0
val_loss, valid_steps, valid_tokens = 0.0, 0, 0
with torch.no_grad():
for _ in range(valid_steps):
input_ids = valid_loader.next_batch()
Review comment (collaborator, PR author):
overwriting input_ids results in training on the test set every val step.

Review comment (reply):
We can just move input_ids = train_loader.next_batch() to right before we pass input_ids to the model during training, right?

val_loss += model(input_ids, sliding_window_size, mlm_probability=0.15)
while (input_ids := valid_loader.next_batch()) is not None:
valid_steps += 1
valid_tokens += (input_ids != 1).sum()
val_loss += model(input_ids, sliding_window_size)
if ddp_world_size > 1:
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
dist.all_reduce(valid_tokens, op=dist.ReduceOp.SUM)
val_loss /= valid_steps
# log val loss to console and to logfile
print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,}')
print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,} tokens: {valid_tokens.item()}')
# start the clock again
torch.cuda.synchronize()
t0 = time.perf_counter()
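
The review thread above is about ordering: because the training batch used to be prefetched before this validation block and the same input_ids variable was then reused for validation/test batches, the next backward pass could run on an evaluation batch. A minimal sketch of the corrected ordering, using the names from the diff (illustrative, not the PR's exact code):

# --- illustrative sketch (not part of the diff) ---
# validation uses its own loop variable and drains the padded loader until it returns None
model.eval()
valid_loader.reset()
val_loss, valid_steps = 0.0, 0
with torch.no_grad():
    while (val_ids := valid_loader.next_batch()) is not None:
        valid_steps += 1
        val_loss += model(val_ids, sliding_window_size)
val_loss /= max(valid_steps, 1)

# training fetches its batch immediately before the forward/backward pass,
# so a variable left over from the validation block can never be trained on
input_ids = train_loader.next_batch()
model(input_ids, sliding_window_size).backward()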
@@ -296,8 +296,8 @@ def get_lr(it):
stack.enter_context(model.no_sync())
#if step >= 5:
# stack.enter_context(torch.compiler.set_stance(skip_guard_eval_unsafe=True))
model(input_ids, sliding_window_size, mlm_probability=0.20).backward()
input_ids = train_loader.next_batch()
model(input_ids, sliding_window_size).backward()
if train_accumulation_steps != 1:
for p in model.parameters():
p.grad /= train_accumulation_steps
@@ -333,13 +333,19 @@ def get_lr(it):
torch.manual_seed(42)
model.eval()
test_loader.reset()
test_loss = 0.0
with torch.no_grad():
for _ in range(test_steps):
input_ids = test_loader.next_batch()
test_loss += model(input_ids, sliding_window_size, mlm_probability=0.15)

test_loss, test_steps, test_tokens = 0.0, 0, 0
with torch.no_grad():
while (input_ids := test_loader.next_batch()) is not None:
test_steps += 1
test_tokens += (input_ids != 1).sum()
test_loss += model(input_ids, sliding_window_size)
if ddp_world_size > 1:
dist.all_reduce(test_loss, op=dist.ReduceOp.AVG)
dist.all_reduce(test_tokens, op=dist.ReduceOp.SUM)
test_loss /= test_steps

print0(f"Test tokens: {test_tokens.item()}")
print0(f"Test results | Loss: {test_loss:.4f} | Perplexity: {math.e**test_loss:.4f}")
print0(f"Total train time (min): {training_time_ms / 60000:.2f}")
print0(f"Total train time (hours): {training_time_ms / 3600000:.2f}")