diff --git a/notebooks/02_quick_debugs.ipynb b/notebooks/02_quick_debugs.ipynb
index f2bd812..0a5366f 100644
--- a/notebooks/02_quick_debugs.ipynb
+++ b/notebooks/02_quick_debugs.ipynb
@@ -2,9 +2,20 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "100.11712"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
@@ -17,6 +28,13 @@
    "# n params\n",
    "sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6"
   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -35,7 +53,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.9"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
  },
diff --git a/torchrun_main.py b/torchrun_main.py
index 0ccbf2b..cc70c17 100644
--- a/torchrun_main.py
+++ b/torchrun_main.py
@@ -728,6 +728,7 @@ def main(args):
     train_dataset = datasets.distributed.split_dataset_by_node(train_dataset, rank=global_rank, world_size=world_size)
     eval_dataset = datasets.distributed.split_dataset_by_node(eval_dataset, rank=global_rank, world_size=world_size)
 
+    logger.info(f"Skipping the first {global_step} batches")
     train_loader = SkipDataLoader(
         train_dataset,
         batch_size=args.batch_size,
@@ -868,12 +869,12 @@
         # MERGE AND REINIT
 
         # restart model after we modify the learning rate, so on the next step after the relora frequency
-        can_reset = args.relora is not None and (
+        can_reset_relora = args.relora is not None and (
             args.resume_from is not None
             or local_step * args.gradient_accumulation > args.relora
         )
 
-        if can_reset and update_step % args.relora == 1:
+        if can_reset_relora and update_step % args.relora == 1:
             logger.info(f"{args.resume_from=}, {local_step=}, {args.relora=}, prod: {local_step * args.gradient_accumulation}")
             logger.info(f"Performing lora reset at update step {update_step}. Current lr is {optimizer.param_groups[0]['lr']}")
             n_lora_restarts += 1
@@ -885,7 +886,12 @@
         else:
             raise ValueError(f"Unknown distributed type {args.distributed_type}")
 
-        if can_reset and update_step % args.cycle_length == 1:
+        can_reset_optimizer = args.relora is not None and (
+            args.resume_from is not None
+            or local_step * args.gradient_accumulation > args.cycle_length
+        )
+
+        if can_reset_optimizer and update_step % args.cycle_length == 1:
             # scheduler should provide a new warmup after the reset
             training_utils.check_lr_and_alert(optimizer, max_lr=1e-4)
 
@@ -900,8 +906,8 @@
         )
 
         # ##############################
-        if can_reset and update_step % args.relora == 2:
-            logger.info(f"First step after lora reset lr is {optimizer.param_groups[0]['lr']}")
+        if can_reset_optimizer and update_step % args.cycle_length == 2:
+            logger.info(f"First step after optimizer reset lr is {optimizer.param_groups[0]['lr']}")
 
         lr = optimizer.param_groups[0]["lr"]
         tokens_in_update = tokens_seen - tokens_seen_before
diff --git a/training_configs/1B_v1.0.yaml b/training_configs/1B_v1.0.yaml
index 4e2696f..7dab0fe 100644
--- a/training_configs/1B_v1.0.yaml
+++ b/training_configs/1B_v1.0.yaml
@@ -5,36 +5,37 @@ workers: 8
 
 # model
 model_name_or_path: EleutherAI/pythia-1b
-model_revision: step10000
+model_revision: step1000
 
 # saving
-save_dir: checkpoints/relora_1b_Aug4_2023_run
+save_dir: checkpoints/relora_1b_Aug5_2023_run2
 autoresume: true
 
 # ReLoRA
 use_peft: true
 force_keep_original: true
 lora_r: 128
-relora: 5320
+relora: 1000
 restart_warmup_steps: 100
 reset_optimizer_on_relora: false
-optimizer_magnitude_pruning: 0.9
+optimizer_magnitude_pruning: 0.8
 
 # Optimization
-optimizer: adam_zero
-batch_size: 64
+optimizer: adam
+batch_size: 8
 total_batch_size: 1024
 lr: 4e-4
 adam_beta1: 0.9
 adam_beta2: 0.95
 weight_decay: 0.01
 scheduler: cosine_restarts
-warmup_steps: 13_000
-num_training_steps: 133_000
-eval_every: 1000
-save_every: 1000
+warmup_steps: 500 # used to be 13_000, but reduced it to comply with the scheduler
+num_training_steps: 130_000 # used to be 133_000, but it's an ugly number
+eval_every: 500
+save_every: 500
 
 # Misc
 dtype: bfloat16
 distributed_type: ddp
-tags: relora1b
+tags: relora1b_debug
+comment: "Checking if ReLoRA 1B loss is similar to regular training loss overnight"
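
Note on the torchrun_main.py change above: the single can_reset guard is split into can_reset_relora (gating the LoRA merge-and-reinit on args.relora) and can_reset_optimizer (gating the optimizer reset on args.cycle_length), and the "first step after reset" log now keys off cycle_length as well. Below is a minimal standalone sketch of that gating logic; the should_reset helper, the SimpleNamespace arguments, and the concrete values are illustrative assumptions, not code from the repository.

    from types import SimpleNamespace

    def should_reset(args, local_step, update_step, resume_from=None):
        # Hypothetical sketch of the decoupled guards: the LoRA merge-and-reinit
        # follows args.relora, the optimizer reset follows args.cycle_length.
        samples_seen = local_step * args.gradient_accumulation
        can_reset_relora = args.relora is not None and (
            resume_from is not None or samples_seen > args.relora
        )
        can_reset_optimizer = args.relora is not None and (
            resume_from is not None or samples_seen > args.cycle_length
        )
        do_lora_reset = can_reset_relora and update_step % args.relora == 1
        do_optimizer_reset = can_reset_optimizer and update_step % args.cycle_length == 1
        return do_lora_reset, do_optimizer_reset

    # Illustrative values: relora matches the debug config above; cycle_length is
    # assumed equal to it, and gradient_accumulation is a made-up placeholder.
    args = SimpleNamespace(relora=1000, cycle_length=1000, gradient_accumulation=128)
    print(should_reset(args, local_step=20, update_step=1001))  # -> (True, True)

With the two guards separated, changing cycle_length moves the optimizer reset (and its log message) without touching the LoRA merge schedule driven by relora.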