 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
+from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context

 logger = logging.getLogger(__name__)


+@scope
 def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:
     assert num_inds <= len(scores)
     sorted_inds = np.argsort(scores)
     uniform_inds = np.linspace(0, len(sorted_inds) - 1, num_inds).astype(int)
     return [int(sorted_inds[i]) for i in uniform_inds]


+@scope
 def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):
     # Cut down the number of trajectories to print
     max_trajs_to_print = 4
@@ -68,6 +71,7 @@ def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):

     buf = io.StringIO()

+    @scope
     def bprint(s: str):
         print(s, file=buf)

@@ -101,6 +105,7 @@ def bprint(s: str):
     logger.info(buf.getvalue().rstrip())


+@scope
 async def optim_step(
     training_client: tinker.TrainingClient,
     learning_rate: float,
@@ -111,13 +116,15 @@ async def optim_step(
     await optim_step_future.result_async()


+@scope
 def remove_mask(datum: tinker.Datum) -> tinker.Datum:
     return tinker.Datum(
         model_input=datum.model_input,
         loss_fn_inputs={k: v for k, v in datum.loss_fn_inputs.items() if k != "mask"},
     )


+@scope
 async def forward_backward(
     training_client: tinker.TrainingClient,
     batch_d: List[tinker.Datum],
@@ -139,6 +146,7 @@ async def forward_backward(
     return training_logprobs_D


+@scope
 async def train_step(
     data_D: List[tinker.Datum],
     training_client: tinker.TrainingClient,
@@ -211,6 +219,7 @@ class Config:

     log_path: str = chz.field(munger=lambda _, s: os.path.expanduser(s))
     base_url: str | None = None
+    enable_trace: bool = False

     remove_constant_reward_groups: bool = False
     eval_every: int = 20
@@ -221,6 +230,7 @@ class Config:
     stream_minibatch_config: StreamMinibatchConfig | None = None


+@scope
 async def do_sync_training_with_stream_minibatch(
     start_batch: int,
     end_batch: int,
@@ -264,6 +274,7 @@ async def do_sync_training_with_stream_minibatch(
     # and the trainer will consume them as soon as they are ready
     trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()

+    @scope
     async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
         metrics = {}
         t_start = time.time()
@@ -286,8 +297,10 @@ async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
         # Sample all trajectories asynchronously. If we have multiple minibatches,
         # then sampling can overlap with training.
         env_group_builders_P = dataset.get_batch(i_batch)
-        for builder in env_group_builders_P:
-            asyncio.create_task(trajectory_group_worker_task(builder))
+        for i, builder in enumerate(env_group_builders_P):
+            asyncio.create_task(
+                trajectory_group_worker_task(builder), name=f"trajectory_group_worker_task_{i}"
+            )

         # Run multiple optimizer substeps per training iteration
         sampling_client, full_batch_metrics = await do_train_step_streaming_and_get_sampling_client(
@@ -322,6 +335,7 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)


+@scope
 async def do_async_training(
     start_batch: int,
     end_batch: int,
@@ -360,6 +374,7 @@ async def do_async_training(
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()

+    @scope
     def shutdown_loops():
         """Trigger all loops to shutdown"""
         shutdown_event.set()
@@ -368,6 +383,7 @@ def shutdown_loops():
         env_group_builders_queue.put_nowait(None)
         sampling_client_updated_event.set()

+    @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
@@ -377,6 +393,7 @@ async def dataloader_loop():
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1

+    @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
         while not shutdown_event.is_set():
@@ -407,6 +424,7 @@ async def trajectory_group_worker_loop():
                )
            )

+    @scope
     async def training_loop():
         """
         Waits for a sufficient number of valid trajectories to be accumulated and trains on them.
@@ -421,6 +439,7 @@ async def training_loop():
             if wrapped_trajectory_group is None:
                 continue

+            @scope
             def filter_stale_trajectory_group(
                 wrapped_trajectory_group: WrappedTrajectoryGroup | None,
             ) -> bool:
@@ -437,7 +456,8 @@ def filter_stale_trajectory_group(
                 ):
                     logger.info(f"[training_loop] Step {i_batch}: Samples are too stale, skipping")
                     asyncio.create_task(
-                        env_group_builders_queue.put(wrapped_trajectory_group.env_group_builder)
+                        env_group_builders_queue.put(wrapped_trajectory_group.env_group_builder),
+                        name="requeue_stale_sample_task",
                     )
                     return False
                 return True
@@ -505,6 +525,7 @@ def filter_stale_trajectory_group(

         shutdown_loops()

+    @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
@@ -529,13 +550,19 @@ async def evaluation_loop():
             ml_logger.log_metrics(metrics, step=sampling_client_eval_step)

     await asyncio.gather(
-        dataloader_loop(),
-        *[trajectory_group_worker_loop() for _ in range(cfg.async_config.groups_per_batch)],
-        training_loop(),
-        evaluation_loop(),
+        asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
+        *[
+            asyncio.create_task(
+                trajectory_group_worker_loop(), name=f"trajectory_group_worker_loop_{i}"
+            )
+            for i in range(cfg.async_config.groups_per_batch)
+        ],
+        asyncio.create_task(training_loop(), name="training_loop"),
+        asyncio.create_task(evaluation_loop(), name="evaluation_loop"),
     )


+@scope
 async def do_group_rollout_and_filter_constant_reward(
     cfg: Config,
     sampling_client: tinker.SamplingClient,
@@ -553,6 +580,7 @@ async def do_group_rollout_and_filter_constant_reward(
     return trajectory_groups[0]


+@scope
 async def save_checkpoint_and_get_sampling_client(
     cfg: Config,
     training_client: tinker.TrainingClient,
@@ -570,6 +598,7 @@ async def save_checkpoint_and_get_sampling_client(
     return training_client.create_sampling_client(path_dict["sampler_path"]), metrics


+@scope
 async def prepare_minibatch(
     cfg: Config,
     env_group_builders_P: Sequence[EnvGroupBuilder],
@@ -608,6 +637,7 @@ async def prepare_minibatch(
     return data_D, metrics


+@scope
 async def compute_full_batch_metrics_and_get_sampling_client(
     cfg: Config,
     training_client: tinker.TrainingClient,
@@ -641,6 +671,7 @@ async def compute_full_batch_metrics_and_get_sampling_client(
     return sampling_client, metrics


+@scope
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
@@ -666,6 +697,9 @@ async def do_train_step_streaming_and_get_sampling_client(
     # Number of groups per minibatch in each optimizer substep
     groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches

+    context = get_scope_context()
+    context.attributes["step"] = i_batch
+
     metrics = {}

     # Run multiple optimizer substeps per training iteration
@@ -744,6 +778,7 @@ async def do_train_step_streaming_and_get_sampling_client(
     return sampling_client, metrics


+@scope
 async def do_train_step_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
@@ -753,6 +788,9 @@ async def do_train_step_and_get_sampling_client(
     env_group_builders_P: Sequence[EnvGroupBuilder],
     trajectory_groups_P: list[TrajectoryGroup],
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+    context = get_scope_context()
+    context.attributes["step"] = i_batch
+
     metrics = {}
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
         cfg,
@@ -785,6 +823,7 @@ async def do_train_step_and_get_sampling_client(
     return sampling_client, metrics


+@scope
 async def do_sync_training(
     start_batch: int,
     end_batch: int,
@@ -824,9 +863,12 @@ async def do_sync_training(
         with timed("sample", metrics):
             trajectory_groups_P = await asyncio.gather(
                 *[
-                    do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder)
-                    for builder in env_group_builders_P
-                ]
+                    asyncio.create_task(
+                        do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder),
+                        name=f"sample_task_{i}",
+                    )
+                    for i, builder in enumerate(env_group_builders_P)
+                ],
             )
         trajectory_groups_P = [
             trajectory_group
@@ -851,6 +893,7 @@ async def do_sync_training(
         ml_logger.log_metrics(metrics, step=i_batch)


+@scope
 async def main(
     cfg: Config,
 ):
@@ -861,6 +904,18 @@ async def main(
         config=cfg,
         wandb_name=cfg.wandb_name,
     )
+    if cfg.enable_trace:
+        # Get and rename the current (main) task
+        current_task = asyncio.current_task()
+        if current_task is not None:
+            current_task.set_name("main")
+        trace_events_path = os.path.join(cfg.log_path, "trace_events.jsonl")
+        logger.info(f"Tracing is enabled. Trace events will be saved to {trace_events_path}")
+        logger.info(
+            f"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
+        )
+        trace_init(output_file=trace_events_path)
+
     logging.getLogger("httpx").setLevel(logging.WARNING)
     logging.getLogger("pylatexenc").setLevel(logging.WARNING)
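For orientation, here is a minimal sketch of how the tracing utilities introduced by this diff fit together: decorate a function with @scope so it emits a trace event, attach per-step attributes through get_scope_context(), call trace_init() once to start writing events, and give asyncio tasks explicit names so their timelines are legible in the viewer. All names are taken from the diff itself, but the exact signatures and semantics of tinker_cookbook.utils.trace are assumed here, so treat this as an illustration rather than a verified example.

# Illustrative sketch only: assumes tinker_cookbook.utils.trace behaves as the
# diff above implies (a scope decorator, per-scope attributes, a JSONL event log).
import asyncio

from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init


@scope
async def train_one_batch(i_batch: int) -> None:
    # Tag the current scope with the batch index, mirroring
    # do_train_step_and_get_sampling_client in the diff.
    context = get_scope_context()
    context.attributes["step"] = i_batch
    await asyncio.sleep(0.01)  # stand-in for sampling and optimizer substeps


async def main() -> None:
    # Name the root task so it is legible in the trace, as main() does above.
    task = asyncio.current_task()
    if task is not None:
        task.set_name("main")
    trace_init(output_file="/tmp/trace_events.jsonl")
    for i in range(3):
        await train_one_batch(i)


asyncio.run(main())
# Convert and view, per the logger message in main():
#   python tinker_cookbook/utils/trace.py /tmp/trace_events.jsonl trace.json
# then load trace.json in chrome://tracing or https://ui.perfetto.dev/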