8 changes: 4 additions & 4 deletions model.py
@@ -191,8 +191,8 @@ def __init__(self, config: ModelConfig):
         super().__init__(config)
         self.config = config
         tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
-        self.masker = ProteinMasker(tokenizer, 0.20) # 20% masking rate https://arxiv.org/abs/2301.06568
-        self.inference_masker = ProteinMasker(tokenizer, 0.15) # 15% masking rate for inference, ESM2
+        self.masker = ProteinMasker(tokenizer, 0.15) # 15% masking rate https://arxiv.org/abs/2301.06568
+        #self.inference_masker = ProteinMasker(tokenizer, 0.15) # 15% masking rate for inference, ESM2
         self.cls_id = tokenizer.cls_token_id
         self.vocab_size = tokenizer.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
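
For context, a minimal sketch of the masking interface the calls above assume (ProteinMasker's implementation is not part of this diff and may differ, e.g. by applying ESM2/BERT-style 80/10/10 corruption and skipping special tokens): roughly 15% of positions get the mask token in input_ids and keep their original id in labels, while every untouched position is labelled -100 so the cross-entropy loss ignores it.

import torch

def sketch_mask(input_ids: torch.Tensor, mask_token_id: int, rate: float = 0.15):
    # Hypothetical stand-in for ProteinMasker(tokenizer, 0.15); interface only.
    labels = input_ids.clone()
    chosen = torch.rand(input_ids.shape, device=input_ids.device) < rate  # ~15% of positions
    labels[~chosen] = -100                  # default ignore_index of cross-entropy
    masked_ids = input_ids.clone()
    masked_ids[chosen] = mask_token_id      # replace selected tokens with the mask token
    return masked_ids, labels

# usage sketch: masked_ids, labels = sketch_mask(batch, tokenizer.mask_token_id)
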
@@ -275,12 +275,12 @@ def doc_mask_mod(b, h, q_idx, kv_idx):
         return self.get_logits(x)

     def inference(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor = None) -> Tuple[torch.Tensor, Any, Any]:
-        input_ids, labels = self.inference_masker(input_ids)
+        input_ids, labels = self.masker(input_ids)
         logits = self.flex_forward(input_ids, sliding_window_size)
         loss = None
         if labels is not None:
             loss = self.cross_entropy(logits.view(-1, self.vocab_size), labels.view(-1).long())
-        return logits, loss, labels
+        return logits.cpu(), loss.cpu(), labels.cpu()

     def forward(self, input_ids: torch.Tensor, sliding_window_size: torch.Tensor) -> torch.Tensor:
         input_ids, labels = self.masker(input_ids)
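
Because inference() now moves logits, loss and labels to the CPU before returning, a caller can accumulate predictions across many evaluation batches without growing GPU memory. A minimal usage sketch, assuming only the inference() signature above (the model, loader and sliding_window_size arguments are whatever the training script provides):

import torch

def evaluate(model, loader, sliding_window_size, num_eval_batches: int):
    # Accumulate CPU-side predictions and labels; average the per-batch loss.
    all_true, all_pred, total_loss = [], [], 0.0
    model.eval()
    with torch.no_grad():
        for _ in range(num_eval_batches):
            input_ids = loader.next_batch()
            logits, loss, labels = model.inference(input_ids, sliding_window_size)
            all_true.append(labels.flatten())                # already on CPU
            all_pred.append(logits.argmax(dim=-1).flatten())
            total_loss += loss.item()
    return torch.cat(all_pred), torch.cat(all_true), total_loss / num_eval_batches
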
46 changes: 33 additions & 13 deletions train_esm2.py
@@ -31,6 +31,8 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from pathlib import Path
 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, matthews_corrcoef
+import warnings
+warnings.filterwarnings("ignore")

 from optimizer import Muon
 from model import ModelConfig, ESM, CastedLinear
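
The blanket warnings.filterwarnings("ignore") above most likely exists to silence sklearn's UndefinedMetricWarning, which weighted precision/recall raise whenever some token class receives no predictions in a batch. A narrower filter would keep unrelated warnings visible; this is an alternative sketch, not what the PR does:

import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence only the undefined-metric warnings instead of every warning in the process.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
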
@@ -59,7 +61,7 @@ def get_args():
     parser.add_argument('--cooldown_steps', type=int, default=1000, help='number of cooldown steps')

     # Evaluation and logging hyperparams
-    parser.add_argument('--valid_loss_every', type=int, default=1000, help='every how many steps to evaluate val loss? 0 for only at the end')
+    parser.add_argument('--eval_every', type=int, default=1000, help='every how many steps to evaluate val loss? 0 for only at the end')
     parser.add_argument('--hf_model_name', type=str, default='Synthyra/esm_speedrun', help='huggingface model name')
     parser.add_argument('--token', type=str, default=None, help='huggingface token')
     parser.add_argument('--save_every', type=int, default=None, help='save every how many steps? None for no saving')
@@ -161,7 +163,7 @@ def print0(s, logonly=False):
 # load tokens
 train_loader = DistributedDataLoader(args.input_bin, batch_size, ddp_rank, ddp_world_size)
 valid_loader = DistributedDataLoader(args.input_valid_bin, batch_size, ddp_rank, ddp_world_size)
-test_loader = DistributedDataLoader(args.input_test_bin, batch_size // 4, ddp_rank, ddp_world_size)
+test_loader = DistributedDataLoader(args.input_test_bin, batch_size, ddp_rank, ddp_world_size)
 print0(f"Training DataLoader: total number of tokens: {train_loader.total_num_tokens} across {len(train_loader.files)} files")
 print0(f"Validation DataLoader: total number of tokens: {valid_loader.total_num_tokens} across {len(valid_loader.files)} files")
 print0(f"Testing DataLoader: total number of tokens: {test_loader.total_num_tokens} across {len(test_loader.files)} files")
@@ -241,23 +243,46 @@ def get_lr(it):
     sw_prev = sw_size

     # once in a while evaluate the validation dataset
-    if args.valid_loss_every > 0 and step % args.valid_loss_every == 0 or last_step:
+    if args.eval_every > 0 and step % args.eval_every == 0 or last_step:
         # stop the clock
         torch.cuda.synchronize()
         training_time_ms += 1000 * (time.perf_counter() - t0)
         # run validation batches
         model.eval()
         valid_loader.reset()
         val_loss = 0.0
+        val_true, val_pred = [], []
         with torch.no_grad():
             for _ in range(valid_steps):
                 input_ids = valid_loader.next_batch()
-                val_loss += model(input_ids, sliding_window_size)
+                logits, loss, labels = model.inference(input_ids, sliding_window_size)
+                val_true.extend(labels.cpu().numpy().flatten())
+                val_pred.extend(logits.argmax(dim=-1).cpu().numpy().flatten())
+                val_loss += loss.detach().cpu().item()
         if ddp_world_size > 1:
             dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
         val_loss /= valid_steps
         # log val loss to console and to logfile
-        print0(f'step:{step}/{args.num_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms perplexity:{(math.e**val_loss):.4f} param_count:{get_param_count(model):,}')
+        val_perplexity = torch.exp(torch.tensor(val_loss)).item()
+
+        # Calculate validation metrics
+        val_true = np.array(val_true)
+        val_pred = np.array(val_pred)
+        mask = (val_true != -100)
+        val_true = val_true[mask]
+        val_pred = val_pred[mask]
+
+        val_precision = precision_score(val_true, val_pred, average='weighted')
+        val_recall = recall_score(val_true, val_pred, average='weighted')
+        val_f1 = f1_score(val_true, val_pred, average='weighted')
+        val_accuracy = accuracy_score(val_true, val_pred)
+        val_mcc = matthews_corrcoef(val_true, val_pred)
+
+        # log validation metrics to console
+        print0(f'step:{step}/{args.num_steps}')
+        print0(f'Loss: {val_loss:.4f} | Perplexity: {val_perplexity:.4f}')
+        print0(f'Precision: {val_precision:.4f} | Recall: {val_recall:.4f} | F1: {val_f1:.4f} | Accuracy: {val_accuracy:.4f} | MCC: {val_mcc:.4f}')
+        print0(f'Train Time: {training_time_ms:.0f}ms | Step Avg: {training_time_ms/(timed_steps-1):.2f}ms | Param Count: {get_param_count(model):,}')

         # start the clock again
         torch.cuda.synchronize()
         t0 = time.perf_counter()
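
A self-contained sketch of the metric computation added above, using toy arrays: positions that were never masked carry the label -100 and are filtered out before scoring, so the metrics describe only the masked-token predictions. The zero_division=0 argument is added here to suppress undefined-metric warnings locally; the script silences them globally instead.

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, matthews_corrcoef

# toy labels/predictions over 8 token positions; -100 marks unmasked positions
val_true = np.array([-100, 4, 7, -100, 4, 9, -100, 7])
val_pred = np.array([3, 4, 7, 5, 6, 9, 7, 7])

keep = val_true != -100                      # score only the masked positions
val_true, val_pred = val_true[keep], val_pred[keep]

print(accuracy_score(val_true, val_pred))                                        # 0.8
print(precision_score(val_true, val_pred, average='weighted', zero_division=0))
print(recall_score(val_true, val_pred, average='weighted', zero_division=0))
print(f1_score(val_true, val_pred, average='weighted', zero_division=0))
print(matthews_corrcoef(val_true, val_pred))
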
@@ -351,13 +376,8 @@ def get_lr(it):
 mcc = matthews_corrcoef(all_true, all_pred)

 print0("Final Results:")
-print0(f" Loss: {average_loss:.4f}")
-print0(f" Perplexity: {perplexity:.4f}")
-print0(f" Precision: {precision:.4f}")
-print0(f" Recall: {recall:.4f}")
-print0(f" F1: {f1:.4f}")
-print0(f" Accuracy: {accuracy:.4f}")
-print0(f" MCC: {mcc:.4f}")
+print0(f"Loss: {average_loss:.4f}, Perplexity: {perplexity:.4f}")
+print0(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}, MCC: {mcc:.4f}")

 print0(f"peak memory consumption testing: {torch.cuda.max_memory_allocated() // 1024 // 1024 // 1024} GiB")
 # -------------------------------------------------------------------------
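
For reference, the perplexity reported by both the old and the new logging code is the exponential of the mean cross-entropy loss in nats; the removed math.e**val_loss and the added torch.exp(torch.tensor(val_loss)) compute the same quantity. A quick check with a made-up loss value:

import math
import torch

val_loss = 2.1972  # hypothetical mean cross-entropy in nats
print(math.e ** val_loss)                        # ~9.0
print(torch.exp(torch.tensor(val_loss)).item())  # same value, up to float precision
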