Skip to content

Commit

Permalink
Updates for v0.1.0 (see release notes!)
Browse files Browse the repository at this point in the history
  • Loading branch information
tysam-code committed Mar 13, 2023
1 parent d7ea4d2 commit 4846ccd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 26 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ Goals:
* eventual world-record-speed single-GPU training times on at least one LLM benchmark (currently via the A100! :D)


This code implements a fast-training language model baseline that achieves under ~3.8 val loss (44.7 perplexity) on WikiText-103 in just over 6 minutes. It is currently a distilled, feature-pruned, relatively faithful reimplementation of a basic GPT language model as defined in Karpathy's excellent [nanoGPT](https://github.com/karpathy/nanoGPT) repository**. The only differences in this implentation that I know of are different function calls for the attention operation, the full-accuracy Pytorch GELU function, and the native PyTorch OneCycle scheduler. There are also different default hyperparameters specific to these short runs. For the rationale behind the 3.8 val loss target, please see the '[Why 3.8 Val Loss?](#why-3.8-val-loss?)' section.
This code implements a fast-training language model baseline that achieves under ~3.8 val loss (44.7 perplexity) on WikiText-103 in just over 3 minutes. It is currently a relatively minimal model based on relatively faithful reimplementation of a basic GPT language model as defined in Karpathy's excellent [nanoGPT](https://github.com/karpathy/nanoGPT) repository**. We've made a number of changes to improve the training speed on the tiny (~100M) training set that we use. For the rationale behind the 3.8 val loss target, please see the '[Why 3.8 Val Loss?](#why-3.8-val-loss?)' section.


This is a very focused implementation which attempts to maximize code understandability and minimize code length. At the same time, it aims to be very hackable to let people test out new ideas and get back initial results rapidly. We also want to keep this code as accessible as possible to a wide range of users and usecases -- mainly through layout, helpful code comments, and simplicity. We also only target one piece of hardware -- the A100 currently -- but attempt to maintain accessibility by providing options for people with less GPU memory. As as result of all of this, this means that this implementation really doesn't have much in the way of fancy features. It downloads and loads the data, creates the network, runs the training loop and that's about it -- if you want anything on top of that, it should be easily to implement with how modular the code is. That said, feel free to open an issue if there is something critical that I've missed!
This is a very focused codebase which attempts to maximize code understandability and minimize code length. At the same time, it aims to be very hackable to let people test out new ideas and get back initial results rapidly. We also want to keep this code as accessible as possible to a wide range of users and usecases -- mainly through layout, helpful code comments, and simplicity. We also only target one piece of hardware -- the A100 currently -- but attempt to maintain accessibility by providing options for people with less GPU memory. As as result of all of this, this means that this implementation really doesn't have much in the way of fancy features. It downloads and loads the data, creates the network, runs the training loop and that's about it -- if you want anything on top of that, it should be easily to implement with how modular the code is. That said, feel free to open an issue if there is something critical that I've missed!


Finally, this code is meant to be fast -- as fast as possible. Please keep an eye out for further training speedups in future updates, since this is just the baseline after all. This code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. Part of the recommended workflow at scale is that if you're at an organization that needs a modified 'base repo', to modify this repo and use that as your new base internally. I oftentimes use a branching tree structure several repos deep in my work and I find it to be a great way to rapidly explore/context switch/roll back between different problem-solving domains. It's also another reason why I keep the base repo so simple.
Expand All @@ -44,7 +44,7 @@ Feel free to check out my [Patreon](https://www.patreon.com/user/posts?u=8363213

### Known Bugs / Potential Problem Areas

The Colab-specific code is commented out at the top, and some of the model weight initialization and flops/mfu/etc calculations might require you to update them manually if you are making significant changes to the network.
The Colab-specific code is commented out at the top, and some of the model weight initialization and flops/mfu/etc calculations might require you to update them manually if you are making significant changes to the network. There's currently some bugs relating to the dataloader and the number of steps we run doing training -- if just manually measuring the time to get to <3.8 loss, you should be okay for now. Hopefully I'll be able to fix this in a future release.

### Why 3.8 Val Loss?

Expand All @@ -64,7 +64,7 @@ If you use this work in your research, please cite
month={3},
title={{hlb-gpt}},
url={https://github.com/tysam-code/hlb-gpt},
version = {0.0.0},
version = {0.1.0},
year = {2023}}`

### Bugs & Etc.
Expand Down
85 changes: 63 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

# Check if we're using pytorch 2 for those speedups
using_pytorch_2 = (int(torch.__version__.split('.')[0]) >= 2)
if not using_pytorch_2:
print("Info: Pytorch 2.0 isn't currently installed. Falling back to slower Pytorch 1.x pathway.")

## <-- teaching comments
# <-- functional comments
Expand Down Expand Up @@ -59,12 +61,12 @@
hyp = {
'opt': {
'lr': 2e-3,
'weight_decay': 1e-3,
'total_train_steps': 900,
'weight_decay': 1.5e-2,
'total_train_steps': 6000,
'eval_iter': 50, # how many train iterations we wait in between eval rounds (we don't include eval time in our performance stats)
'warmup_percent': .15, ## what percent of the training run to warmup the learning rate over
'accumulate_steps': 5*(8//4), # via nanoGPT: simulate 8 gpus, 5 accum steps for a larger batchsize (esp necesary towards the end).
}, # since this model is tiny, we can increase the batchsize and decrease this by a factor of 4
'warmup_percent': .05, ## what percent of the training run to warmup the learning rate over
'initial_accumulate_steps': 1, # It's good for speed to start small (i.e. 1) here and tune target_per_step_decay carefully to grow to the appropriate num of accumulate steps over traiing
},
'net': {
'residual_depth': 384, ## this should be a factor of 8 in some way to stay tensor core friendly
'num_heads': 6,
Expand Down Expand Up @@ -157,8 +159,8 @@ def __init__(self, num_features, sequence_length, num_heads):
self.attention = nn.MultiheadAttention(num_features, num_heads, bias=False, batch_first=True)

## this mask makes sure that each part of a sequence can only attend to the tokens that come behind it.
self.causal_mask = torch.logical_not(torch.triu(torch.ones((sequence_length, sequence_length), device=hyp['misc']['device'], dtype=torch.bool))).T # TODO: way to simplify this?
self.causal_mask = torch.logical_not(torch.triu(torch.ones((sequence_length, sequence_length), device=hyp['misc']['device'], dtype=torch.bool))).T # TODO: way to simplify this? (see: after pytorch 2.0 release, causal=True on the scaled_dot_product_attention fn)

def forward(self, x):
residual = x
x = self.norm(x)
Expand Down Expand Up @@ -229,7 +231,7 @@ def make_net():
network_dict = nn.ModuleDict({
'embedding': nn.Embedding(hyp['misc']['num_tokens'], hyp['net']['residual_depth']),
'position': PositionEmbedding(hyp['misc']['sequence_length'], hyp['net']['residual_depth']),
'norm': LayerNorm(hyp['net']['residual_depth'], eps=1e-5, bias=False),
'norm': LayerNorm(hyp['net']['residual_depth'], bias=False),
'mlp_layers': nn.ModuleList([MLPBlock(hyp['net']['residual_depth']) for _ in range(hyp['net']['num_blocks'])]),
'attn_layers': nn.ModuleList([AttentionBlock(hyp['net']['residual_depth'], hyp['misc']['sequence_length'], hyp['net']['num_heads']) for _ in range(hyp['net']['num_blocks'])]),
'outputs': nn.Linear(hyp['net']['residual_depth'], hyp['misc']['num_tokens'], bias=False),
Expand All @@ -243,7 +245,7 @@ def make_net():
net.net_dict['embedding'].weight = net.net_dict['outputs'].weight

for name, parameter in net.named_parameters():
# TODO: Way to tidy this up for a future release?
# TODO: Way to tidy this up for a future release? (once pytorch 2.0 releases we can use the scaled_dot_product attention, update the names appropriately, and point to an older release for people using PT <2.0)
# Initialize both embedding layers (embedding and position) and the non-bias values of the 'normal' linear layers (outputs, expand, in_proj)
if 'embedding' in name or 'position' in name or (('outputs' in name or 'expand' in name or 'in_proj' in name) and 'weight' in name):
torch.nn.init.normal_(parameter.data, mean=0., std=.02) # normal init
Expand Down Expand Up @@ -326,12 +328,23 @@ def init_split_parameter_dictionaries(net):

return params_non_decay, params_decay

@torch.compile
def get_grad_norm(net):
# Gets the entire grad norm of the network.
grad_norm = torch.tensor(0., device=hyp['misc']['device'])
for p in net.parameters():
if p.grad is not None:
param_norm = p.grad.detach().data.norm(2)
grad_norm += param_norm.square()
grad_norm = (grad_norm ** 0.5).item()
return grad_norm


## Just your good ol', normal an' nice xentropy function. Which makes sense if (in the ideal scenario) we only see each datapoint one single time.
## However! If (esp for tiny datsets) we're seeing our data multiple times in a row, then maybe some smoothing to help regularize things a bit is in order.... :D
loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)

logging_columns_list = ['epoch', 'current_steps', 'train_loss', 'val_loss', 'val_perplexity', 'train_acc', 'val_acc', 'a100_mfu', 'total_time_seconds']
logging_columns_list = ['epoch', 'current_steps', 'train_loss', 'val_loss', 'val_perplexity', 'train_acc', 'val_acc', 'grad_norm', 'a100_mfu', 'total_time_seconds']
# define the printing function and print the column heads
def print_training_details(columns_list, separator_left='| ', separator_right=' ', final="|", column_heads_only=False, is_final_entry=False):
print_string = ""
Expand Down Expand Up @@ -388,13 +401,21 @@ def eval(net):
return val_acc, val_loss, val_perplexity

def main():
# Initializing constants for the whole run.
# Initializing variables for the whole run.
total_time_seconds = 0.
current_steps = 0
#train_loss = 10. # for the updating terminal printout, initialized to roughly the initial loss.
grad_norm = 0. # initialize the grad norm calculation
microbatches_since_last_eval = 0. # TODO: Good way to simplify this?
running_grad_norm_decay = .95
target_per_step_decay = 3e-2 # what absolute step size we should target each training step. the effective batchsize is scaled to try to meet this target. :)
accumulate_steps_lr = 5e-2 # smooths out the automatic batchsize scaling rate
running_grad_norm = 1.2 # initialized roughly to what the initial grad norm is. the ema that we use for tracking our grad norm over time
current_accumulate_steps = accumulate_steps_estimate = hyp['opt']['initial_accumulate_steps'] # current_accumulate_steps is the per-microbatch sampled steps, accumulate_steps_estimate is the actual estimated fractional value determining our projected batchsize

num_steps_per_epoch = len(data['train']) // (batchsize * hyp['misc']['sequence_length'])
# Note: This is a static calculation of the total number of microbatches up front, you may have to change this depending upon what you're tinkering with
total_microbatch_steps = hyp['opt']['total_train_steps'] * hyp['opt']['accumulate_steps']
total_microbatch_steps = hyp['opt']['total_train_steps'] * hyp['opt']['initial_accumulate_steps'] # BUG: Since we have dynamic virtual batchsize scaling now, we're going to have to rewrite the dataloader to appropriately handle it now.


# Get network
Expand All @@ -404,7 +425,7 @@ def main():
params_non_decay, params_decay = init_split_parameter_dictionaries(net)
adamw_speedup_mode = {'fused': True} if using_pytorch_2 else {'foreach': True}
opt = torch.optim.AdamW([params_non_decay, params_decay], **adamw_speedup_mode)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=hyp['opt']['lr'], total_steps=hyp['opt']['total_train_steps'], pct_start=hyp['opt']['warmup_percent'], anneal_strategy='cos', cycle_momentum=False, div_factor=1e2, final_div_factor=.05)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=hyp['opt']['lr'], total_steps=hyp['opt']['total_train_steps'], pct_start=hyp['opt']['warmup_percent'], anneal_strategy='linear', cycle_momentum=False, div_factor=1e2, final_div_factor=.02)

## For accurately timing GPU code
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
Expand Down Expand Up @@ -435,41 +456,61 @@ def main():
loss = loss_fn(outputs.flatten(0, 1), targets.flatten(0, 1))

# Quick non-eval summary every N training steps
if current_steps % 10 == 0 and microbatch_step % hyp['opt']['accumulate_steps'] == 0 and not current_steps % hyp['opt']['eval_iter'] == 0:
if current_steps % 10 == 0 and microbatch_step % current_accumulate_steps == 0 and not current_steps % hyp['opt']['eval_iter'] == 0:
train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
train_loss = loss.detach().cpu().item()
train_summary_variables = {'epoch': microbatch_step//num_steps_per_epoch, 'current_steps': current_steps, 'train_loss': train_loss, 'train_acc': train_acc}
train_summary_variables = {'epoch': microbatch_step//num_steps_per_epoch, 'current_steps': current_steps, 'train_loss': train_loss, 'train_acc': train_acc, 'grad_norm': grad_norm}
print_training_details(list(map(partial(format_for_table, locals=train_summary_variables), logging_columns_list)))

loss.backward()
loss.div(current_accumulate_steps).backward()
microbatches_since_last_eval += 1

## Once we've accumulated steps over all of our microbatches, take a single full-batchsize step.
if microbatch_step % hyp['opt']['accumulate_steps'] == 0:
if microbatch_step % current_accumulate_steps == 0:
## Step the optimizer, then scheduler
opt.step()

# Dynamic weight decay scheduling. Based upon the squared log likelihood of the data [inspired by section 5 of https://arxiv.org/pdf/2204.02311.pdf]
# (up to its max value at likelihood = 1, which we should in all...likelihood...never reach. :')))) )
# Still evaluating the top-end of this option vs a few other options out there.
opt.param_groups[1]['weight_decay'] = (1./loss.detach().item())**2. * hyp['opt']['weight_decay']
scheduler.step()

# The next several lines calculate a dynamic batchsize, simulated through manual dithering
# There could be improvements or losses in changing the dithering strategy, since determinism and gradient descent can lead to some very not-so-nice (and subtle) loss oscillations.
# First, manually calculate the grad norm here (no clipping or anything)
grad_norm = get_grad_norm(net) # TODO: Can/should we evaluate every N steps instead?

running_grad_norm = running_grad_norm_decay * running_grad_norm + (1. - running_grad_norm_decay) * grad_norm

per_step_diff_delta = target_per_step_decay - (running_grad_norm - grad_norm)
# Scale the learning rate by the current number of accumulate steps so we're able to be nimble even if steps take a very long time
accumulate_steps_estimate += current_accumulate_steps * (accumulate_steps_lr * per_step_diff_delta)
# Clamp our fractional accumulate steps estimate so it doesn't go below 1
accumulate_steps_estimate = max(1., accumulate_steps_estimate)
base, probability = divmod(accumulate_steps_estimate, 1)
# Randomly sample next accumulate steps to use
current_accumulate_steps = max(1, int(base + torch.bernoulli(torch.tensor(probability)).item())) # bernoulli via torch to save an unnecesary import :)

## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than under the default 'set to zero' method
opt.zero_grad(set_to_none=True)
current_steps += 1

# Since we're not running over epochs anymore, we have to manually calculate what epoch it is.
epoch = microbatch_step//num_steps_per_epoch

if current_steps % hyp['opt']['eval_iter'] == 0:
ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
train_loss = loss.detach().cpu().item() # To have an updated loss to compare with the eval loss

opt.zero_grad(set_to_none=True)
net.eval()

val_acc, val_loss, val_perplexity = eval(net)
average_time_per_batch = 1e-3 * starter.elapsed_time(ender)/hyp['opt']['eval_iter']

a100_mfu, _ = get_net_mfu_and_param_counts(net, batchsize, hyp['opt']['accumulate_steps'], avg_time_per_batch=average_time_per_batch)
a100_mfu, _ = get_net_mfu_and_param_counts(net, batchsize, microbatches_since_last_eval/hyp['opt']['eval_iter'], avg_time_per_batch=average_time_per_batch)
microbatches_since_last_eval = 0 # necessary for accurate mfu counts. How totally necessary is mfu here if we're mainly using wallclock time?
is_final_eval = (current_steps == hyp['opt']['total_train_steps']) # If we're at the end of training, do a full eval instead

# Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
Expand Down

0 comments on commit 4846ccd

Please sign in to comment.