Merged

36 commits, all by finbarrtimbers:
2c514a5  Oct 21, 2025  Now, we get num_attention_heads from the hf config.
4e9b772  Oct 21, 2025  Update code
4375c59  Oct 21, 2025  Added test that we match manual values
633310b  Oct 21, 2025  Updated calculations
72abb5c  Oct 23, 2025  Updated code with check_calculation
bea2537  Oct 24, 2025  Updated code
29514ad  Oct 28, 2025  Now, tests pass.
66170b8  Oct 28, 2025  Updated code to normalize properly
9b96a01  Oct 29, 2025  Added some fixes
1e1f559  Oct 29, 2025  Updated code
1090a27  Oct 29, 2025  Updated code
3c6c3fd  Oct 29, 2025  Another fix
4d19c57  Oct 29, 2025  Cleaned up tests.
bc572cf  Oct 29, 2025  Cleaned up PR
2237da3  Oct 30, 2025  Update MFU/MBU code.
387d12d  Oct 30, 2025  Now, mbu tests pass.
13c7d94  Oct 30, 2025  Moved to json file
3e43c7f  Oct 30, 2025  Added test data
696e090  Oct 30, 2025  undid changes and simplified test function.
db89ec8  Oct 30, 2025  An attempt at a fix
073ab31  Oct 30, 2025  Update code with patches
e4a74a4  Oct 30, 2025  now, tests pass
bed21f1  Oct 31, 2025  Added MFU to DPO
9bff1a9  Oct 31, 2025  updated script
1b667b1  Oct 31, 2025  uses uv for dpo
31fe56d  Oct 31, 2025  Added a chat template to the DPO script.
efb5f06  Oct 31, 2025  Added trackign
81823fe  Oct 31, 2025  Updated code to handle tracking when none
8fe6c1b  Oct 31, 2025  Added description updates
7233c2e  Oct 31, 2025  undid changes
ff3aab3  Oct 31, 2025  Check out dpo script
cfc0f94  Oct 31, 2025  updated script
3217640  Oct 31, 2025  Update code to remove whitespace
29ab5f9  Oct 31, 2025  fix finetune timing
cb4caca  Oct 31, 2025  Merge branch 'main' into wandb-descriptions
2edde8e  Oct 31, 2025  Fixed bugs pointed out by cursor.
11 changes: 10 additions & 1 deletion open_instruct/dpo_tune_cache.py
@@ -79,6 +79,7 @@
is_beaker_job,
launch_ai2_evals_on_weka,
maybe_get_beaker_config,
maybe_update_beaker_description,
maybe_use_ai2_hf_entity,
maybe_use_ai2_wandb_entity,
)
@@ -498,6 +499,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
},
)
wandb_tracker = accelerator.get_tracker("wandb")
maybe_update_beaker_description(wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None)

if accelerator.is_main_process:
pprint([args, tc])
@@ -813,6 +815,7 @@ def load_model():
print("=============after cache logprobs; clear cache")
print_gpu_stats(init_gpu_memory)
# Only show the progress bar once on each machine.
start_time = time.perf_counter()
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
# update the progress_bar if load from checkpoint
progress_bar.update(completed_steps)
@@ -936,6 +939,12 @@ def load_model():
logger.info(logger_str)
if args.with_tracking:
accelerator.log(metrics_to_log, step=completed_steps)
maybe_update_beaker_description(
current_step=completed_steps,
total_steps=args.max_train_steps,
start_time=start_time,
wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None,
)
# Reset the local metrics
local_metrics.zero_()

@@ -989,7 +998,7 @@ def load_model():
path=args.output_dir,
leaderboard_name=args.hf_repo_revision,
oe_eval_max_length=args.oe_eval_max_length,
wandb_url=wandb_tracker.run.get_url(),
wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None,
oe_eval_tasks=args.oe_eval_tasks,
gs_bucket_path=args.gs_bucket_path,
)
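Both new call sites pass keyword arguments only, and one of them supplies nothing but `wandb_url`, so every parameter of `maybe_update_beaker_description` is presumably optional. Below is a minimal sketch of what such a helper might look like; the `BEAKER_EXPERIMENT_ID` variable, the beaker-py `Beaker.from_env()` / `experiment.set_description()` calls, and the ETA format are assumptions for illustration, not details taken from this PR.

```python
# Hypothetical sketch; the real implementation lives in open_instruct/utils.py.
# Assumes a BEAKER_EXPERIMENT_ID env var and beaker-py's client API.
import os
import time
from typing import Optional


def maybe_update_beaker_description(
    current_step: Optional[int] = None,
    total_steps: Optional[int] = None,
    start_time: Optional[float] = None,
    wandb_url: Optional[str] = None,
) -> None:
    """Best-effort: surface the wandb URL and training progress in Beaker."""
    experiment_id = os.environ.get("BEAKER_EXPERIMENT_ID")  # assumed env var
    if experiment_id is None:
        return  # not running as a Beaker job; nothing to update

    parts = []
    if wandb_url:
        parts.append(wandb_url)
    if current_step is not None and total_steps and start_time is not None:
        # ETA from average seconds per completed step so far.
        elapsed = time.perf_counter() - start_time
        remaining = (total_steps - current_step) * elapsed / max(current_step, 1)
        parts.append(f"{current_step}/{total_steps} steps, ~{remaining / 3600:.1f}h remaining")

    try:
        from beaker import Beaker  # beaker-py; import guarded for non-Beaker envs

        client = Beaker.from_env()
        client.experiment.set_description(experiment_id, " | ".join(parts))
    except Exception:
        pass  # a failed description update should never interrupt training
```

Swallowing exceptions here is deliberate: the description is a convenience, and network or auth hiccups must not take down a long training run.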
22 changes: 16 additions & 6 deletions open_instruct/finetune.py
@@ -62,6 +62,7 @@
is_beaker_job,
launch_ai2_evals_on_weka,
maybe_get_beaker_config,
maybe_update_beaker_description,
maybe_use_ai2_hf_entity,
maybe_use_ai2_wandb_entity,
)
@@ -438,6 +439,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
},
)
wandb_tracker = accelerator.get_tracker("wandb")
maybe_update_beaker_description(wandb_url=wandb_tracker.run.get_url())
else:
wandb_tracker = None # for later eval launching

@@ -727,7 +729,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_pred_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
start_time = time.time()
start_time = time.perf_counter()
skipped_batches = False
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
@@ -824,10 +826,12 @@ def main(args: FlatArguments, tc: TokenizerConfig):
"avg_tokens_per_batch": avg_tokens_per_batch,
"avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
"avg_pred_tokens_per_batch": avg_pred_tokens_per_batch,
"per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
"per_device_tps": total_tokens
/ accelerator.num_processes
/ (time.perf_counter() - start_time),
"per_device_tps_including_padding": total_tokens_including_padding
/ accelerator.num_processes
/ (time.time() - start_time),
/ (time.perf_counter() - start_time),
"reserved_mem_GiB": torch.cuda.max_memory_reserved(device=torch.cuda.current_device()) / 2**30,
"allocated_mem_GiB": torch.cuda.max_memory_allocated(device=torch.cuda.current_device())
/ 2**30,
@@ -855,7 +859,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
avg_loss = sum_loss / total_fwd_passes
metrics_to_log["train_loss"] = avg_loss
if args.verbose:
sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
sec_per_step = (time.perf_counter() - start_time) / (completed_steps - resume_step)
steps_remaining = args.max_train_steps - completed_steps
secs_remaining = steps_remaining * sec_per_step
accelerator.print(
@@ ... @@
/ args.logging_steps
)
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.time() - start_time)}"
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.perf_counter() - start_time)}"
)
metrics_to_log["aux_loss"] = avg_aux_loss
else:
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.perf_counter() - start_time)}"
)
if args.verbose:
accelerator.print(f"{metrics_to_log=}")
if args.with_tracking:
accelerator.log(metrics_to_log, step=completed_steps)
maybe_update_beaker_description(
current_step=completed_steps,
total_steps=args.max_train_steps,
start_time=start_time,
wandb_url=wandb_tracker.run.get_url() if wandb_tracker is not None else None,
)
total_loss = 0
total_aux_loss = 0

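A recurring change in this file replaces `time.time()` with `time.perf_counter()` in every throughput calculation. `time.perf_counter()` is a monotonic, high-resolution clock, so elapsed-time deltas cannot jump or go negative when NTP or an operator adjusts the wall clock, which keeps tokens-per-second metrics stable over long runs. A self-contained sketch of the pattern (the workload here is just a stand-in for a training step):

```python
import time

num_items = 1_000_000
start = time.perf_counter()  # monotonic clock: safe for measuring intervals
checksum = sum(i * i for i in range(num_items))  # stand-in for a training step
elapsed = time.perf_counter() - start
print(f"{num_items / elapsed:,.0f} items/s (checksum={checksum})")
```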
15 changes: 12 additions & 3 deletions scripts/train/debug/dpo.sh
@@ -1,8 +1,16 @@
python mason.py \
#!/bin/bash
BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}"

uv run python mason.py \
--cluster ai2/neptune \
--cluster ai2/saturn \
--cluster ai2/jupiter \
--cluster ai2/prior \
--description "Single GPU DPO run, for debugging purposes." \
--workspace ai2/tulu-thinker \
--priority high \
--image nathanl/open_instruct_auto --pure_docker_mode \
--image "$BEAKER_IMAGE" \
--pure_docker_mode \
--preemptible \
--num_nodes 1 \
--budget ai2/oe-adapt \
@@ -26,5 +34,6 @@ python mason.py \
--logging_steps 1 \
--dataset_mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 100 \
--add_bos \
--chat_template_name olmo \
--seed 123
# --with_tracking
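With the `${1:-nathanl/open_instruct_auto}` default, the rewritten script takes the Beaker image as an optional first positional argument. Hypothetical invocations (the override image name is illustrative, not from the PR):

```bash
# Use the default image:
bash scripts/train/debug/dpo.sh

# Override with a custom Beaker image:
bash scripts/train/debug/dpo.sh myuser/my-debug-image
```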