
Commit 8a77b05 (parent: 8c80a0d)

1 file changed: 52 additions, 24 deletions

tinker_cookbook/rl/train.py (+52 −24)
@@ -1,6 +1,5 @@
 """
 Implements RL on general MDPs
-
 """

 import asyncio
@@ -55,6 +54,9 @@ def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]

 @scope
 def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):
+    """
+    Print a subset of the trajectory group to the console.
+    """
     # Cut down the number of trajectories to print
     max_trajs_to_print = 4
     if len(traj_group.trajectories_G) > max_trajs_to_print:
@@ -252,7 +254,7 @@ async def do_sync_training_with_stream_minibatch(

     # Initial sampling client
     sampling_client, _ = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, start_batch
+        training_client, start_batch, cfg.log_path, cfg.save_every
     )

     for i_batch in range(start_batch, end_batch):
@@ -279,7 +281,10 @@ async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
         metrics = {}
         t_start = time.time()
         trajectory_group = await do_group_rollout_and_filter_constant_reward(
-            cfg, sampling_client, builder
+            sampling_client,
+            builder,
+            max_tokens=cfg.max_tokens,
+            do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
         )
         metrics["time/trajectory_group_worker_loop/total"] = time.time() - t_start
         if trajectory_group is not None:
@@ -407,9 +412,10 @@ async def trajectory_group_worker_loop():
             # while we're running the rollout
             sampling_client_step_copy = sampling_client_step
             trajectory_group = await do_group_rollout_and_filter_constant_reward(
-                cfg,
                 sampling_client,
                 env_group_builder,
+                max_tokens=cfg.max_tokens,
+                do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
             )
             if trajectory_group is None:
                 trajectory_groups_queue.put_nowait(None)
@@ -564,16 +570,17 @@ async def evaluation_loop():

 @scope
 async def do_group_rollout_and_filter_constant_reward(
-    cfg: Config,
     sampling_client: tinker.SamplingClient,
     env_group_builder: EnvGroupBuilder,
+    max_tokens: int,
+    do_remove_constant_reward_groups: bool,
 ) -> TrajectoryGroup | None:
-    policy = TinkerTokenCompleter(sampling_client, max_tokens=cfg.max_tokens)
+    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
     trajectory_group = await do_group_rollout(env_group_builder, policy)

     # Remove if all trajectories have the same reward
     trajectory_groups = [trajectory_group]
-    if cfg.remove_constant_reward_groups:
+    if do_remove_constant_reward_groups:
         trajectory_groups = remove_constant_reward_groups(trajectory_groups)
         if len(trajectory_groups) == 0:
             return None
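
The hunk above captures the commit's central refactor: do_group_rollout_and_filter_constant_reward now declares the two values it actually reads, rather than accepting the whole Config. A minimal sketch of the pattern, using a stand-in Config dataclass (not the cookbook's real one) and hypothetical function names:

from dataclasses import dataclass

@dataclass
class Config:
    max_tokens: int = 512
    remove_constant_reward_groups: bool = True

# Before: the helper reaches into cfg, so its true inputs are hidden.
def rollout_before(cfg: Config) -> tuple[int, bool]:
    return cfg.max_tokens, cfg.remove_constant_reward_groups

# After: the caller unpacks only what is needed; the signature now
# documents the helper's real dependencies, and tests can call it
# without constructing a full Config.
def rollout_after(max_tokens: int, do_remove_constant_reward_groups: bool) -> tuple[int, bool]:
    return max_tokens, do_remove_constant_reward_groups

cfg = Config()
assert rollout_before(cfg) == rollout_after(
    max_tokens=cfg.max_tokens,
    do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
)
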
@@ -582,29 +589,32 @@ async def do_group_rollout_and_filter_constant_reward(

 @scope
 async def save_checkpoint_and_get_sampling_client(
-    cfg: Config,
     training_client: tinker.TrainingClient,
     i_batch: int,
+    log_path: str,
+    save_every: int,
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
     metrics = {}
     with timed("save_checkpoint", metrics):
         path_dict = await checkpoint_utils.save_checkpoint_async(
             training_client=training_client,
             name=f"{i_batch:06d}",
-            log_path=cfg.log_path,
+            log_path=log_path,
             loop_state={"batch": i_batch},
-            kind="both" if (i_batch > 0 and i_batch % cfg.save_every == 0) else "sampler",
+            kind="both" if (i_batch > 0 and i_batch % save_every == 0) else "sampler",
         )
     return training_client.create_sampling_client(path_dict["sampler_path"]), metrics


 @scope
 async def prepare_minibatch(
-    cfg: Config,
     env_group_builders_P: Sequence[EnvGroupBuilder],
     trajectory_groups_P: list[TrajectoryGroup],
     tokenizer: Tokenizer,
     service_client: tinker.ServiceClient,
+    model_name: str,
+    kl_penalty_coef: float,
+    kl_discount_factor: float,
 ) -> tuple[list[tinker.Datum], dict[str, Any]]:
     """Converts the trajectories into a minibatch, and provides metrics about the minibatch"""
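
In the first half of the hunk above, the kind=... expression encodes the checkpoint schedule: a full checkpoint every save_every batches, and a cheaper sampler-only checkpoint otherwise (the exact meaning of "both" vs. "sampler" is defined by checkpoint_utils, not shown here). A self-contained sketch of just that schedule logic:

def checkpoint_kind(i_batch: int, save_every: int) -> str:
    # Full checkpoint on every save_every-th batch; batch 0 never qualifies.
    return "both" if (i_batch > 0 and i_batch % save_every == 0) else "sampler"

assert checkpoint_kind(0, 5) == "sampler"
assert checkpoint_kind(5, 5) == "both"
assert checkpoint_kind(7, 5) == "sampler"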

@@ -623,14 +633,14 @@ async def prepare_minibatch(
     data_D, _metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)

     # Incorporate KL penalty if configured
-    if cfg.kl_penalty_coef > 0:
+    if kl_penalty_coef > 0:
         with timed("kl_vs_base", metrics):
             kl_penalty_metrics = await incorporate_kl_penalty(
                 data_D,
-                service_client.create_sampling_client(base_model=cfg.model_name),
+                service_client.create_sampling_client(base_model=model_name),
                 # ^^^ TODO: replace with the model we load, if relevant
-                cfg.kl_penalty_coef,
-                cfg.kl_discount_factor,
+                kl_penalty_coef,
+                kl_discount_factor,
             )
         metrics.update(kl_penalty_metrics)
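
incorporate_kl_penalty itself is untouched by this commit; only how its arguments arrive changes. For orientation: a KL penalty against the base model typically shifts per-token advantages by an estimate of the policy/base log-ratio. A hedged sketch under that assumption (apply_kl_penalty and the k1 estimator are illustrative, not the cookbook's actual implementation, and the real function also takes kl_discount_factor, omitted here):

import torch

def apply_kl_penalty(
    advantages: torch.Tensor,       # [T] per-token advantages
    logprobs_policy: torch.Tensor,  # [T] logprobs under the current policy
    logprobs_base: torch.Tensor,    # [T] logprobs under the frozen base model
    kl_penalty_coef: float,
) -> torch.Tensor:
    # k1 estimator of KL(policy || base) from on-policy samples
    kl_hat = logprobs_policy - logprobs_base
    return advantages - kl_penalty_coef * kl_hat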

@@ -639,15 +649,20 @@ async def prepare_minibatch(

 @scope
 async def compute_full_batch_metrics_and_get_sampling_client(
-    cfg: Config,
     training_client: tinker.TrainingClient,
     i_batch: int,
     data_D: list[tinker.Datum],
     training_logprobs_D: list[torch.Tensor],
+    log_path: str,
+    save_every: int,
+    do_compute_post_kl: bool,
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
     """
     At the end of the iteration, this will compute metrics for the full batch
     and return the latest sampling client.
+
+    The reason we return a sampling client is that if do_compute_post_kl is True,
+    we need to create a sampling client from the post-update policy.
     """
     metrics = {}

@@ -658,12 +673,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(

     # Get a sampling client using the new weights
     sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, i_batch
+        training_client, i_batch, log_path, save_every
     )
     metrics.update(checkpoint_metrics)

     # Compute post-KL metrics if configured
-    if cfg.compute_post_kl:
+    if do_compute_post_kl:
         with timed("compute_post_kl", metrics):
             post_kl_metrics = await compute_post_kl(data_D, sampling_client)
             metrics.update(post_kl_metrics)
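
Per the new docstring, the sampling client returned here is built from the post-update weights precisely so that compute_post_kl can compare the updated policy against the data it was just trained on. One plausible shape for such a diagnostic, sketched with illustrative names (the cookbook's actual compute_post_kl may differ):

import torch

def post_update_kl(
    logprobs_behavior: torch.Tensor,  # [T] logprobs under the data-generating policy
    logprobs_updated: torch.Tensor,   # [T] logprobs under the post-update policy
) -> float:
    # k3 estimator of KL(behavior || updated); non-negative by construction
    log_ratio = logprobs_updated - logprobs_behavior
    return (log_ratio.exp() - 1 - log_ratio).mean().item()
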
@@ -728,11 +743,13 @@ async def do_train_step_streaming_and_get_sampling_client(
     # remove these and train on a smaller batch.
     wrapped_trajectory_groups = [g for g in wrapped_trajectory_groups if g is not None]
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
-        cfg,
         [g.env_group_builder for g in wrapped_trajectory_groups],
         [g.trajectory_group for g in wrapped_trajectory_groups],
         tokenizer,
         service_client,
+        model_name=cfg.model_name,
+        kl_penalty_coef=cfg.kl_penalty_coef,
+        kl_discount_factor=cfg.kl_discount_factor,
     )
     metrics.update(prepare_minibatch_metrics)

@@ -767,12 +784,14 @@ async def do_train_step_streaming_and_get_sampling_client(
         sampling_client,
         full_batch_metrics,
     ) = await compute_full_batch_metrics_and_get_sampling_client(
-        cfg,
         training_client,
         # NOTE: saving the checkpoint as the i + 1 step
         i_batch + 1,
         all_data_D,
         all_training_logprobs_D,
+        cfg.log_path,
+        cfg.save_every,
+        cfg.compute_post_kl,
     )
     metrics.update(full_batch_metrics)
     return sampling_client, metrics
@@ -793,11 +812,13 @@ async def do_train_step_and_get_sampling_client(

     metrics = {}
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
-        cfg,
         env_group_builders_P,
         trajectory_groups_P,
         tokenizer,
         service_client,
+        model_name=cfg.model_name,
+        kl_penalty_coef=cfg.kl_penalty_coef,
+        kl_discount_factor=cfg.kl_discount_factor,
     )
     metrics.update(prepare_minibatch_metrics)

@@ -811,12 +832,14 @@ async def do_train_step_and_get_sampling_client(
     )

     sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(
-        cfg,
         training_client,
         # NOTE: saving the checkpoint as the i + 1 step
         i_batch + 1,
         data_D,
         training_logprobs_D,
+        cfg.log_path,
+        cfg.save_every,
+        cfg.compute_post_kl,
     )
     metrics.update(full_batch_metrics)

@@ -840,7 +863,7 @@ async def do_sync_training(

     # Initial sampling client
     sampling_client, _ = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, start_batch
+        training_client, start_batch, cfg.log_path, cfg.save_every
     )

     for i_batch in range(start_batch, end_batch):
@@ -864,7 +887,12 @@ async def do_sync_training(
         trajectory_groups_P = await asyncio.gather(
             *[
                 asyncio.create_task(
-                    do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder),
+                    do_group_rollout_and_filter_constant_reward(
+                        sampling_client,
+                        builder,
+                        max_tokens=cfg.max_tokens,
+                        do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
+                    ),
                     name=f"sample_task_{i}",
                 )
                 for i, builder in enumerate(env_group_builders_P)
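
One detail worth copying from the hunk above: each task is given an explicit name, which makes asyncio debug logs and task dumps legible. A toy, runnable version of the same gather structure (rollout is a stand-in for the real environment rollout):

import asyncio

async def rollout(i: int) -> int:
    await asyncio.sleep(0)  # placeholder for real async work
    return i * i

async def main() -> None:
    results = await asyncio.gather(
        *[
            asyncio.create_task(rollout(i), name=f"sample_task_{i}")
            for i in range(4)
        ]
    )
    print(results)  # [0, 1, 4, 9]

asyncio.run(main())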
