33 changes: 14 additions & 19 deletions dataloading.py
@@ -38,13 +38,13 @@ def advance(self): # advance to next data shard
        self.next_shard = (self.next_shard + 1) % len(self.files)

    def next_batch(self):
-        buf = self.tokens[self.pos + self.rank * self.local_batch_size:][:self.local_batch_size + 1]
+        buf = self.tokens[self.pos + self.rank * self.local_batch_size:][:self.local_batch_size]
        # by @YouJiacheng: host side async is sufficient;
        # no performance improvement was observed when introducing a separate stream.
        sequence = buf.to(device="cuda", dtype=torch.int32, non_blocking=True) # inputs
        # advance current position and load next shard if necessary
        self.pos += self.batch_size
-        if self.pos + self.batch_size + 1 >= len(self.tokens):
+        if self.pos + self.batch_size >= len(self.tokens):
            self.advance()
        return sequence
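
Note: the dropped `+ 1` in this hunk is the causal-LM loader convention, where one extra token is fetched so targets can be the inputs shifted by one; with an MLM objective the labels come from masking the inputs themselves, so exactly `local_batch_size` tokens are needed. A minimal sketch of the distinction (toy buffer and names, not from this repo):

```python
import torch

tokens = torch.arange(16)  # toy token buffer
local_batch_size = 8

# Causal LM: fetch one extra token so targets are the inputs shifted by one.
buf = tokens[:local_batch_size + 1]
causal_inputs, causal_targets = buf[:-1], buf[1:]

# MLM (this PR): labels are produced by masking the inputs themselves,
# so the slice is exactly local_batch_size tokens long.
mlm_inputs = tokens[:local_batch_size]
```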

@@ -60,25 +60,27 @@ def __init__(self, filename_pattern, seq_len, process_rank, num_processes, eos_i
    def advance(self):
        self.pos = 0

-        if self.next_shard // len(self.files) >= self.max_epochs:
-            raw_tokens = self._leftover_tokens
-        else:
-            self.next_shard += 1
+        # handle epoch limit
+        if self.next_shard // len(self.files) < self.max_epochs:
            raw_tokens = _load_data_shard(self.files[self.next_shard % len(self.files)])
            raw_tokens = torch.cat([self._leftover_tokens, raw_tokens], dim=0)
+            self.next_shard += 1
+        else:
+            raw_tokens = self._leftover_tokens

        if not raw_tokens.numel():
            self._leftover_tokens = torch.empty(0, dtype=torch.uint8)
-            self.tokens = None
+            self.tokens = torch.empty(0, dtype=torch.uint8)
            return

        processed_chunks = []
        curr_batch_len = 0

        eos_positions = (raw_tokens == self.eos_id).nonzero(as_tuple=True)[0]
-        for i in range(len(eos_positions)-1):
-            sample_end = eos_positions[i+1]
-            sample = raw_tokens[eos_positions[i]+1:sample_end+1] # One sample: "CLS ... EOS"
+        for i in range(len(eos_positions)):
+            curr_eos = eos_positions[i]
+            prev_eos_plus_one = 0 if i == 0 else eos_positions[i-1] + 1 # EOS_idx + 1 = CLS_idx
+            sample = raw_tokens[prev_eos_plus_one:curr_eos+1] # One sample: "CLS ... EOS"

            assert sample[0] == 0 and sample[-1] == 2, (sample[0], sample[-1])
            assert curr_batch_len < self.local_batch_size, curr_batch_len
@@ -101,12 +103,5 @@ def advance(self):
            processed_chunks.append(sample)
            curr_batch_len += len(sample)

-        self._leftover_tokens = raw_tokens[sample_end+1:]
+        self._leftover_tokens = raw_tokens[curr_eos+1:]
        self.tokens = torch.cat(processed_chunks, dim=0)
-
-    def next_batch(self):
-        if self.tokens is None:
-            return None
-
-        seq = super().next_batch()
-        return seq
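
Note: the loop rewrite fixes a dropped sample: `range(len(eos_positions)-1)` paired consecutive EOS markers, so the sample running from index 0 to the first EOS was never emitted. A toy reproduction of the new indexing, assuming only that CLS has id 0 and EOS id 2, as the asserts above require:

```python
import torch

CLS, EOS = 0, 2  # token ids, matching the asserts in advance()
raw = torch.tensor([CLS, 5, 6, EOS, CLS, 7, EOS, CLS, 8, 9, EOS])

eos_positions = (raw == EOS).nonzero(as_tuple=True)[0]

samples = []
for i in range(len(eos_positions)):
    curr_eos = eos_positions[i]
    start = 0 if i == 0 else eos_positions[i - 1] + 1  # EOS_idx + 1 = CLS_idx
    samples.append(raw[start:curr_eos + 1])            # "CLS ... EOS"

# All three samples are recovered, including the first one that the
# old range(len(eos_positions) - 1) loop silently skipped.
assert len(samples) == 3
assert all(s[0] == CLS and s[-1] == EOS for s in samples)
```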
65 changes: 43 additions & 22 deletions train_esm2.py
@@ -38,18 +38,18 @@

def get_args():
    parser = argparse.ArgumentParser(description='ESM2 training arguments')

    # Model hyperparams
    parser.add_argument('--vocab_size', type=int, default=33, help='vocabulary size')
    parser.add_argument('--num_hidden_layers', type=int, default=24, help='number of transformer layers')
    parser.add_argument('--num_attention_heads', type=int, default=6, help='number of attention heads (head dim 128 suggested by @Grad62304977)')
    parser.add_argument('--hidden_size', type=int, default=768, help='model hidden dimension size')

    # Data hyperparams
    parser.add_argument('--input_bin', type=str, default='data/omgprot50/omgprot50_train_*.bin', help='input .bins to train on')
    parser.add_argument('--input_valid_bin', type=str, default='data/omgprot50/omgprot50_valid_*.bin', help='input .bins to eval validation loss on')
-    parser.add_argument('--input_test_bin', type=str, default='data/omgprot50/omgprot50_test_*.bin', help='input .bins to eval test loss on')
+    parser.add_argument('--input_test_bin', type=str, default='data/omgprot50/omgprot50_test_*.bin', help='input .bins to eval test loss on')

    # Optimization hyperparams
    parser.add_argument('--batch_size', type=int, default=8*64*1024, help='batch size, in tokens, across all devices')
    parser.add_argument('--grad_accum', type=int, default=1, help='manually set number of gradient accumulation steps, else, will be ddp_world_size')
@@ -99,7 +99,7 @@ def get_param_count(model):
    master_process = (ddp_rank == 0)
else:
    ddp_rank = 0
-    ddp_local_rank = 0
+    ddp_local_rank = 0
    ddp_world_size = 1
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
@@ -149,7 +149,7 @@ def print0(s, logonly=False):
assert ddp_world_size == 1 or args.grad_accum == 1, 'Cannot currently use both DDP and gradient accumulation'
if ddp_world_size > 1:
    train_accumulation_steps = ddp_world_size
-    batch_size = args.batch_size // ddp_world_size
+    batch_size = args.batch_size // ddp_world_size
elif args.grad_accum > 1:
    train_accumulation_steps *= args.grad_accum
    batch_size = args.batch_size // args.grad_accum
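
Note: both branches preserve the global token budget of `args.batch_size` per optimizer step: DDP splits it across ranks, gradient accumulation splits it across micro-batches on one rank. A quick sanity sketch with hypothetical sizes:

```python
global_batch = 8 * 64 * 1024  # default --batch_size, in tokens

# DDP path: 8 ranks each process 1/8 of the tokens per optimizer step.
ddp_world_size = 8
assert (global_batch // ddp_world_size) * ddp_world_size == global_batch

# Grad-accum path: one rank accumulates 8 micro-batches before stepping.
grad_accum = 8
assert (global_batch // grad_accum) * grad_accum == global_batch
```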
@@ -248,18 +248,20 @@ def get_lr(it):
        # run validation batches
        model.eval()
        valid_loader.reset()
-        val_loss, valid_steps, valid_tokens = 0.0, 0, 0
+        val_loss, valid_tokens = 0.0, 0
        with torch.no_grad():
-            while (input_ids := valid_loader.next_batch()) is not None:
-                valid_steps += 1
-                valid_tokens += (input_ids != 1).sum()
-                val_loss += model(input_ids, sliding_window_size, mlm_probability=0.15)
+            input_ids = valid_loader.next_batch()
+            while input_ids.numel():
+                batch_valid_tokens = (input_ids != pad_id).sum()
+                valid_tokens += batch_valid_tokens
+                val_loss += model(input_ids, sliding_window_size) * batch_valid_tokens
+                input_ids = valid_loader.next_batch()
        if ddp_world_size > 1:
-            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+            dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(valid_tokens, op=dist.ReduceOp.SUM)
-        val_loss /= valid_steps
+        val_loss /= valid_tokens
        # log val loss to console and to logfile
-        print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,} tokens: {valid_tokens.item()}')
+        print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,} tokens: {valid_tokens.item():,}')
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()
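
Note: the move from `ReduceOp.AVG` over per-batch means to token-weighted `SUM`s matters whenever batches carry different numbers of non-pad tokens: an unweighted mean over batches over-weights small batches, while weighting each batch loss by its token count and dividing by the global token total gives the true per-token loss. A small worked example of the gap (toy numbers):

```python
import torch

batch_losses = torch.tensor([2.0, 4.0])      # per-token mean loss of each batch
batch_tokens = torch.tensor([900.0, 100.0])  # non-pad tokens in each batch

# Old scheme: unweighted mean over batches (then AVG across ranks).
naive = batch_losses.mean()  # 3.0

# New scheme: weight each batch by its tokens, divide by the token total.
weighted = (batch_losses * batch_tokens).sum() / batch_tokens.sum()  # 2.2

# The 100-token batch pulled the naive average up by 0.8 nats.
```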
@@ -334,19 +334,38 @@ def get_lr(it):
model.eval()
test_loader.reset()

-test_loss, test_steps, test_tokens = 0.0, 0, 0
+test_loss, test_tokens = 0.0, 0
+with torch.no_grad():
+    input_ids = test_loader.next_batch()
+    while input_ids.numel():
+        batch_test_tokens = (input_ids != pad_id).sum()
+        test_tokens += batch_test_tokens
+        test_loss += model(input_ids, sliding_window_size, mlm_probability=0.15) * batch_test_tokens
+        input_ids = test_loader.next_batch()
+if ddp_world_size > 1:
+    dist.all_reduce(test_loss, op=dist.ReduceOp.SUM)
+    dist.all_reduce(test_tokens, op=dist.ReduceOp.SUM)
+test_loss /= test_tokens
+
+original_test_loss = test_loss
+print0(f"Original test loss (regular forward pass): {original_test_loss:.4f}")
+
+test_loss, test_tokens = 0.0, 0
all_logits, all_labels = [], []
with torch.no_grad():
-    while (input_ids := test_loader.next_batch()) is not None:
-        test_steps += 1
-        test_tokens += (input_ids != 1).sum()
+    input_ids = test_loader.next_batch()
+    while input_ids.numel():
+        batch_test_tokens = (input_ids != pad_id).sum()
+        test_tokens += batch_test_tokens
        logits, loss, labels = model.inference(input_ids, sliding_window_size, mlm_probability=0.15)
+        test_loss += loss * batch_test_tokens
        all_logits.extend(logits.detach().cpu().flatten().tolist())
        all_labels.extend(labels.detach().cpu().flatten().tolist())
+        input_ids = test_loader.next_batch()
if ddp_world_size > 1:
-    dist.all_reduce(test_loss, op=dist.ReduceOp.AVG)
+    dist.all_reduce(test_loss, op=dist.ReduceOp.SUM)
+    dist.all_reduce(test_tokens, op=dist.ReduceOp.SUM)
-test_loss /= test_steps
+test_loss /= test_tokens

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, matthews_corrcoef
@@ -362,14 +383,14 @@ def get_lr(it):
test_accuracy = accuracy_score(all_labels, all_logits)
test_mcc = matthews_corrcoef(all_labels, all_logits)

-print0(f"Test results (inference pass): {test_tokens.item()}")
+print0(f'Test tokens: {test_tokens.item()}')
print0(f'Loss: {test_loss:.4f} | Perplexity: {math.e**test_loss:.4f}')
print0(f'Precision: {test_precision:.4f} | Recall: {test_recall:.4f} | F1: {test_f1:.4f} | Accuracy: {test_accuracy:.4f} | MCC: {test_mcc:.4f}')
print0(f'Train Time: {training_time_ms:.0f}ms | Step Avg: {training_time_ms/(timed_steps-1):.2f}ms | Param Count: {get_param_count(model):,}')
print0(f'Total train time (min): {training_time_ms / 60000:.2f}')
print0(f'Total train time (hours): {training_time_ms / 3600000:.2f}')

-print0(f'peak memory consumption testing: {torch.cuda.max_memory_allocated() // 1024 // 1024 // 1024} GiB')
+print0(f"peak memory consumption testing: {torch.cuda.max_memory_allocated() // 1024 // 1024 // 1024} GiB")
# -------------------------------------------------------------------------
# clean up nice
if ddp_world_size > 1: