11"""
22Implements RL on general MDPs
3-
43"""
54
65import asyncio
@@ -55,6 +54,9 @@ def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]

 @scope
 def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):
+    """
+    Print a subset of the trajectory group to the console.
+    """
     # Cut down the number of trajectories to print
     max_trajs_to_print = 4
     if len(traj_group.trajectories_G) > max_trajs_to_print:
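
For context on the truncation above: the hunk header mentions _select_representative_inds(scores, num_inds), which presumably picks which trajectories survive the cut. A minimal sketch of that kind of selection follows; the actual implementation is not shown in this diff, so the body below is an assumption:

# Hypothetical sketch: spread the picks across the sorted score range so the
# printed subset covers low-, mid-, and high-reward trajectories.
def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    if num_inds >= len(order):
        return sorted(order)
    # Evenly spaced positions in the sorted order, endpoints included.
    step = (len(order) - 1) / max(num_inds - 1, 1)
    positions = {round(j * step) for j in range(num_inds)}
    return sorted(order[p] for p in positions)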
@@ -252,7 +254,7 @@ async def do_sync_training_with_stream_minibatch(

     # Initial sampling client
     sampling_client, _ = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, start_batch
+        training_client, start_batch, cfg.log_path, cfg.save_every
     )

     for i_batch in range(start_batch, end_batch):
@@ -279,7 +281,10 @@ async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
     metrics = {}
     t_start = time.time()
     trajectory_group = await do_group_rollout_and_filter_constant_reward(
-        cfg, sampling_client, builder
+        sampling_client,
+        builder,
+        max_tokens=cfg.max_tokens,
+        do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
     )
     metrics["time/trajectory_group_worker_loop/total"] = time.time() - t_start
     if trajectory_group is not None:
@@ -407,9 +412,10 @@ async def trajectory_group_worker_loop():
         # while we're running the rollout
         sampling_client_step_copy = sampling_client_step
         trajectory_group = await do_group_rollout_and_filter_constant_reward(
-            cfg,
             sampling_client,
             env_group_builder,
+            max_tokens=cfg.max_tokens,
+            do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
         )
         if trajectory_group is None:
             trajectory_groups_queue.put_nowait(None)
@@ -564,16 +570,17 @@ async def evaluation_loop():

 @scope
 async def do_group_rollout_and_filter_constant_reward(
-    cfg: Config,
     sampling_client: tinker.SamplingClient,
     env_group_builder: EnvGroupBuilder,
+    max_tokens: int,
+    do_remove_constant_reward_groups: bool,
 ) -> TrajectoryGroup | None:
-    policy = TinkerTokenCompleter(sampling_client, max_tokens=cfg.max_tokens)
+    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens)
     trajectory_group = await do_group_rollout(env_group_builder, policy)

     # Remove if all trajectories have the same reward
     trajectory_groups = [trajectory_group]
-    if cfg.remove_constant_reward_groups:
+    if do_remove_constant_reward_groups:
         trajectory_groups = remove_constant_reward_groups(trajectory_groups)
     if len(trajectory_groups) == 0:
         return None
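
This hunk is the heart of the refactor: the helper no longer takes the whole Config, only the two fields it actually reads. One payoff is that it can now be exercised without constructing a Config at all; a minimal sketch (the client and builder setup are assumed, not shown in this diff):

async def rollout_once(sampling_client: tinker.SamplingClient, builder: EnvGroupBuilder):
    group = await do_group_rollout_and_filter_constant_reward(
        sampling_client,
        builder,
        max_tokens=512,  # assumed value; normally cfg.max_tokens
        do_remove_constant_reward_groups=True,
    )
    # None means every trajectory in the group earned the same reward, so the
    # group carries zero advantage signal and was filtered out.
    return group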
@@ -582,29 +589,32 @@ async def do_group_rollout_and_filter_constant_reward(

 @scope
 async def save_checkpoint_and_get_sampling_client(
-    cfg: Config,
     training_client: tinker.TrainingClient,
     i_batch: int,
+    log_path: str,
+    save_every: int,
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
     metrics = {}
     with timed("save_checkpoint", metrics):
         path_dict = await checkpoint_utils.save_checkpoint_async(
             training_client=training_client,
             name=f"{i_batch:06d}",
-            log_path=cfg.log_path,
+            log_path=log_path,
             loop_state={"batch": i_batch},
-            kind="both" if (i_batch > 0 and i_batch % cfg.save_every == 0) else "sampler",
+            kind="both" if (i_batch > 0 and i_batch % save_every == 0) else "sampler",
         )
     return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
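
A note on the kind expression above: a sampler-only checkpoint is written every batch (it backs the returned sampling client), while a full state checkpoint is written only on multiples of save_every, skipping batch 0. A quick illustration of that selection logic, with save_every=20 as an example value:

save_every = 20  # example value
for i_batch in [0, 1, 19, 20, 21, 40]:
    kind = "both" if (i_batch > 0 and i_batch % save_every == 0) else "sampler"
    print(i_batch, kind)  # 0, 1, 19, 21 -> "sampler"; 20, 40 -> "both"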


 @scope
 async def prepare_minibatch(
-    cfg: Config,
     env_group_builders_P: Sequence[EnvGroupBuilder],
     trajectory_groups_P: list[TrajectoryGroup],
     tokenizer: Tokenizer,
     service_client: tinker.ServiceClient,
+    model_name: str,
+    kl_penalty_coef: float,
+    kl_discount_factor: float,
 ) -> tuple[list[tinker.Datum], dict[str, Any]]:
     """Converts the trajectories into a minibatch, and provides metrics about the minibatch"""

@@ -623,14 +633,14 @@ async def prepare_minibatch(
     data_D, _metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)

     # Incorporate KL penalty if configured
-    if cfg.kl_penalty_coef > 0:
+    if kl_penalty_coef > 0:
         with timed("kl_vs_base", metrics):
             kl_penalty_metrics = await incorporate_kl_penalty(
                 data_D,
-                service_client.create_sampling_client(base_model=cfg.model_name),
+                service_client.create_sampling_client(base_model=model_name),
                 # ^^^ TODO: replace with the model we load, if relevant
-                cfg.kl_penalty_coef,
-                cfg.kl_discount_factor,
+                kl_penalty_coef,
+                kl_discount_factor,
             )
             metrics.update(kl_penalty_metrics)

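
incorporate_kl_penalty itself is not shown in this diff. The usual shape of such a penalty is to estimate per-token KL against the frozen base model and fold it into the rewards, scaled by kl_penalty_coef (with kl_discount_factor presumably discounting the KL terms when they are accumulated into returns). A hedged sketch of that idea, not the actual implementation:

import torch

def kl_penalized_rewards(
    rewards: torch.Tensor,      # per-token rewards
    logp_policy: torch.Tensor,  # log-probs under the current policy
    logp_base: torch.Tensor,    # log-probs under the frozen base model
    kl_penalty_coef: float,
) -> torch.Tensor:
    # Single-sample KL estimate per token; subtracting it pulls the policy
    # back toward the base model's distribution.
    kl_per_token = logp_policy - logp_base
    return rewards - kl_penalty_coef * kl_per_token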
@@ -639,15 +649,20 @@ async def prepare_minibatch(

 @scope
 async def compute_full_batch_metrics_and_get_sampling_client(
-    cfg: Config,
     training_client: tinker.TrainingClient,
     i_batch: int,
     data_D: list[tinker.Datum],
     training_logprobs_D: list[torch.Tensor],
+    log_path: str,
+    save_every: int,
+    do_compute_post_kl: bool,
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
     """
     At the end of the iteration, this will compute metrics for the full batch
     and return the latest sampling client.
+
+    The reason we return a sampling client is that if do_compute_post_kl is True,
+    we need to create a sampling client from the post-update policy.
     """
     metrics = {}

@@ -658,12 +673,12 @@ async def compute_full_batch_metrics_and_get_sampling_client(

     # Get a sampling client using the new weights
     sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, i_batch
+        training_client, i_batch, log_path, save_every
     )
     metrics.update(checkpoint_metrics)

     # Compute post-KL metrics if configured
-    if cfg.compute_post_kl:
+    if do_compute_post_kl:
         with timed("compute_post_kl", metrics):
             post_kl_metrics = await compute_post_kl(data_D, sampling_client)
             metrics.update(post_kl_metrics)
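
compute_post_kl is also not shown here; the standard post-update diagnostic is the KL between the policy that sampled the batch and the freshly updated policy, estimated from log-probs on the same tokens — which is why the new sampling client must be created first. A rough sketch under those assumptions:

import torch

def post_update_kl(logp_sampler: torch.Tensor, logp_updated: torch.Tensor) -> dict[str, float]:
    # Mean single-sample KL estimate; larger values mean the optimizer step
    # moved the policy further from the distribution that generated the data.
    return {"post_kl/mean": (logp_sampler - logp_updated).mean().item()}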
@@ -728,11 +743,13 @@ async def do_train_step_streaming_and_get_sampling_client(
     # remove these and train on a smaller batch.
     wrapped_trajectory_groups = [g for g in wrapped_trajectory_groups if g is not None]
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
-        cfg,
         [g.env_group_builder for g in wrapped_trajectory_groups],
         [g.trajectory_group for g in wrapped_trajectory_groups],
         tokenizer,
         service_client,
+        model_name=cfg.model_name,
+        kl_penalty_coef=cfg.kl_penalty_coef,
+        kl_discount_factor=cfg.kl_discount_factor,
     )
     metrics.update(prepare_minibatch_metrics)

@@ -767,12 +784,14 @@ async def do_train_step_streaming_and_get_sampling_client(
         sampling_client,
         full_batch_metrics,
     ) = await compute_full_batch_metrics_and_get_sampling_client(
-        cfg,
         training_client,
         # NOTE: saving the checkpoint as the i + 1 step
         i_batch + 1,
         all_data_D,
         all_training_logprobs_D,
+        cfg.log_path,
+        cfg.save_every,
+        cfg.compute_post_kl,
     )
     metrics.update(full_batch_metrics)
     return sampling_client, metrics
@@ -793,11 +812,13 @@ async def do_train_step_and_get_sampling_client(

     metrics = {}
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
-        cfg,
         env_group_builders_P,
         trajectory_groups_P,
         tokenizer,
         service_client,
+        model_name=cfg.model_name,
+        kl_penalty_coef=cfg.kl_penalty_coef,
+        kl_discount_factor=cfg.kl_discount_factor,
     )
     metrics.update(prepare_minibatch_metrics)

@@ -811,12 +832,14 @@ async def do_train_step_and_get_sampling_client(
     )

     sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(
-        cfg,
         training_client,
         # NOTE: saving the checkpoint as the i + 1 step
         i_batch + 1,
         data_D,
         training_logprobs_D,
+        cfg.log_path,
+        cfg.save_every,
+        cfg.compute_post_kl,
     )
     metrics.update(full_batch_metrics)

@@ -840,7 +863,7 @@ async def do_sync_training(

     # Initial sampling client
     sampling_client, _ = await save_checkpoint_and_get_sampling_client(
-        cfg, training_client, start_batch
+        training_client, start_batch, cfg.log_path, cfg.save_every
     )

     for i_batch in range(start_batch, end_batch):
@@ -864,7 +887,12 @@ async def do_sync_training(
         trajectory_groups_P = await asyncio.gather(
             *[
                 asyncio.create_task(
-                    do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder),
+                    do_group_rollout_and_filter_constant_reward(
+                        sampling_client,
+                        builder,
+                        max_tokens=cfg.max_tokens,
+                        do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,
+                    ),
                     name=f"sample_task_{i}",
                 )
                 for i, builder in enumerate(env_group_builders_P)
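
The gather-over-named-tasks pattern used above is worth noting: name=f"sample_task_{i}" labels each task so it is identifiable in debug output and tracebacks. A standalone toy version of the same pattern:

import asyncio

async def work(i: int) -> int:
    await asyncio.sleep(0)  # stand-in for a rollout
    return i * i

async def main() -> None:
    results = await asyncio.gather(
        *[asyncio.create_task(work(i), name=f"sample_task_{i}") for i in range(4)]
    )
    print(results)  # [0, 1, 4, 9]

asyncio.run(main())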