Commit 8c80a0d

[tinker-cookbook] rl: add tracing library (#28)
1 parent: b097d19
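
This commit threads a lightweight tracing layer through the RL training loop: functions are wrapped in a `@scope` decorator so each call is recorded as a span, `trace_init` starts writing trace events to a JSONL file, `get_scope_context` lets the current scope be tagged with attributes (here the training step), and asyncio tasks are given explicit names so spans can be attributed to the worker that produced them. Below is a minimal sketch of that usage pattern, using only the call shapes visible in the diff; the internals of `tinker_cookbook/utils/trace.py` are not shown in this view, and `do_train_step` here is a hypothetical stand-in function.

# Sketch only: call shapes taken from the diff; trace.py internals are assumed.
import asyncio

from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context


@scope  # records a span for every call of the decorated function
async def do_train_step(i_batch: int) -> None:
    context = get_scope_context()          # context of the currently open span
    context.attributes["step"] = i_batch   # tag the span with the step number


async def main() -> None:
    trace_init(output_file="trace_events.jsonl")  # begin writing trace events
    # Named tasks make each worker identifiable in the resulting trace.
    await asyncio.gather(
        *[asyncio.create_task(do_train_step(i), name=f"step_{i}") for i in range(4)]
    )


asyncio.run(main())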

File tree: 3 files changed (+615, -10 lines)


tinker_cookbook/rl/train.py

Lines changed: 65 additions & 10 deletions
@@ -40,17 +40,20 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
+from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context

 logger = logging.getLogger(__name__)


+@scope
 def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:
     assert num_inds <= len(scores)
     sorted_inds = np.argsort(scores)
     uniform_inds = np.linspace(0, len(sorted_inds) - 1, num_inds).astype(int)
     return [int(sorted_inds[i]) for i in uniform_inds]


+@scope
 def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):
     # Cut down the number of trajectories to print
     max_trajs_to_print = 4
@@ -68,6 +71,7 @@ def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):

     buf = io.StringIO()

+    @scope
     def bprint(s: str):
         print(s, file=buf)

@@ -101,6 +105,7 @@ def bprint(s: str):
     logger.info(buf.getvalue().rstrip())


+@scope
 async def optim_step(
     training_client: tinker.TrainingClient,
     learning_rate: float,
@@ -111,13 +116,15 @@ async def optim_step(
     await optim_step_future.result_async()


+@scope
 def remove_mask(datum: tinker.Datum) -> tinker.Datum:
     return tinker.Datum(
         model_input=datum.model_input,
         loss_fn_inputs={k: v for k, v in datum.loss_fn_inputs.items() if k != "mask"},
     )


+@scope
 async def forward_backward(
     training_client: tinker.TrainingClient,
     batch_d: List[tinker.Datum],
@@ -139,6 +146,7 @@ async def forward_backward(
     return training_logprobs_D


+@scope
 async def train_step(
     data_D: List[tinker.Datum],
     training_client: tinker.TrainingClient,
@@ -211,6 +219,7 @@ class Config:

     log_path: str = chz.field(munger=lambda _, s: os.path.expanduser(s))
     base_url: str | None = None
+    enable_trace: bool = False

     remove_constant_reward_groups: bool = False
     eval_every: int = 20
@@ -221,6 +230,7 @@ class Config:
     stream_minibatch_config: StreamMinibatchConfig | None = None


+@scope
 async def do_sync_training_with_stream_minibatch(
     start_batch: int,
     end_batch: int,
@@ -264,6 +274,7 @@ async def do_sync_training_with_stream_minibatch(
     # and the trainer will consume them as soon as they are ready
     trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()

+    @scope
     async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
         metrics = {}
         t_start = time.time()
@@ -286,8 +297,10 @@ async def trajectory_group_worker_task(builder: EnvGroupBuilder) -> None:
         # Sample all trajectories asynchronously. If we have multiple minibatches,
         # then sampling can overlap with training.
         env_group_builders_P = dataset.get_batch(i_batch)
-        for builder in env_group_builders_P:
-            asyncio.create_task(trajectory_group_worker_task(builder))
+        for i, builder in enumerate(env_group_builders_P):
+            asyncio.create_task(
+                trajectory_group_worker_task(builder), name=f"trajectory_group_worker_task_{i}"
+            )

         # Run multiple optimizer substeps per training iteration
         sampling_client, full_batch_metrics = await do_train_step_streaming_and_get_sampling_client(
@@ -322,6 +335,7 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)


+@scope
 async def do_async_training(
     start_batch: int,
     end_batch: int,
@@ -360,6 +374,7 @@ async def do_async_training(
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()

+    @scope
     def shutdown_loops():
         """Trigger all loops to shutdown"""
         shutdown_event.set()
@@ -368,6 +383,7 @@ def shutdown_loops():
         env_group_builders_queue.put_nowait(None)
         sampling_client_updated_event.set()

+    @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
@@ -377,6 +393,7 @@ async def dataloader_loop():
             await env_group_builders_queue.put(env_group_builder)
             i_batch += 1

+    @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
         while not shutdown_event.is_set():
@@ -407,6 +424,7 @@ async def trajectory_group_worker_loop():
                 )
             )

+    @scope
     async def training_loop():
         """
         Waits for a sufficient number of valid trajectories to be accumulated and trains on them.
@@ -421,6 +439,7 @@ async def training_loop():
             if wrapped_trajectory_group is None:
                 continue

+            @scope
             def filter_stale_trajectory_group(
                 wrapped_trajectory_group: WrappedTrajectoryGroup | None,
             ) -> bool:
@@ -437,7 +456,8 @@ def filter_stale_trajectory_group(
                 ):
                     logger.info(f"[training_loop] Step {i_batch}: Samples are too stale, skipping")
                     asyncio.create_task(
-                        env_group_builders_queue.put(wrapped_trajectory_group.env_group_builder)
+                        env_group_builders_queue.put(wrapped_trajectory_group.env_group_builder),
+                        name="requeue_stale_sample_task",
                     )
                     return False
                 return True
@@ -505,6 +525,7 @@ def filter_stale_trajectory_group(

         shutdown_loops()

+    @scope
     async def evaluation_loop():
         """Runs evals periodically"""
         if len(evaluators) == 0 or cfg.eval_every == 0:
@@ -529,13 +550,19 @@ async def evaluation_loop():
             ml_logger.log_metrics(metrics, step=sampling_client_eval_step)

     await asyncio.gather(
-        dataloader_loop(),
-        *[trajectory_group_worker_loop() for _ in range(cfg.async_config.groups_per_batch)],
-        training_loop(),
-        evaluation_loop(),
+        asyncio.create_task(dataloader_loop(), name="dataloader_loop"),
+        *[
+            asyncio.create_task(
+                trajectory_group_worker_loop(), name=f"trajectory_group_worker_loop_{i}"
+            )
+            for i in range(cfg.async_config.groups_per_batch)
+        ],
+        asyncio.create_task(training_loop(), name="training_loop"),
+        asyncio.create_task(evaluation_loop(), name="evaluation_loop"),
     )


+@scope
 async def do_group_rollout_and_filter_constant_reward(
     cfg: Config,
     sampling_client: tinker.SamplingClient,
@@ -553,6 +580,7 @@
     return trajectory_groups[0]


+@scope
 async def save_checkpoint_and_get_sampling_client(
     cfg: Config,
     training_client: tinker.TrainingClient,
@@ -570,6 +598,7 @@
     return training_client.create_sampling_client(path_dict["sampler_path"]), metrics


+@scope
 async def prepare_minibatch(
     cfg: Config,
     env_group_builders_P: Sequence[EnvGroupBuilder],
@@ -608,6 +637,7 @@
     return data_D, metrics


+@scope
 async def compute_full_batch_metrics_and_get_sampling_client(
     cfg: Config,
     training_client: tinker.TrainingClient,
@@ -641,6 +671,7 @@
     return sampling_client, metrics


+@scope
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
@@ -666,6 +697,9 @@
     # Number of groups per minibatch in each optimizer substep
     groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches

+    context = get_scope_context()
+    context.attributes["step"] = i_batch
+
     metrics = {}

     # Run multiple optimizer substeps per training iteration
@@ -744,6 +778,7 @@
     return sampling_client, metrics


+@scope
 async def do_train_step_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
@@ -753,6 +788,9 @@
     env_group_builders_P: Sequence[EnvGroupBuilder],
     trajectory_groups_P: list[TrajectoryGroup],
 ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+    context = get_scope_context()
+    context.attributes["step"] = i_batch
+
     metrics = {}
     data_D, prepare_minibatch_metrics = await prepare_minibatch(
         cfg,
@@ -785,6 +823,7 @@
     return sampling_client, metrics


+@scope
 async def do_sync_training(
     start_batch: int,
     end_batch: int,
@@ -824,9 +863,12 @@
         with timed("sample", metrics):
             trajectory_groups_P = await asyncio.gather(
                 *[
-                    do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder)
-                    for builder in env_group_builders_P
-                ]
+                    asyncio.create_task(
+                        do_group_rollout_and_filter_constant_reward(cfg, sampling_client, builder),
+                        name=f"sample_task_{i}",
+                    )
+                    for i, builder in enumerate(env_group_builders_P)
+                ],
             )
         trajectory_groups_P = [
             trajectory_group
@@ -851,6 +893,7 @@
         ml_logger.log_metrics(metrics, step=i_batch)


+@scope
 async def main(
     cfg: Config,
 ):
@@ -861,6 +904,18 @@
         config=cfg,
         wandb_name=cfg.wandb_name,
     )
+    if cfg.enable_trace:
+        # Get and rename the current (main) task
+        current_task = asyncio.current_task()
+        if current_task is not None:
+            current_task.set_name("main")
+        trace_events_path = os.path.join(cfg.log_path, "trace_events.jsonl")
+        logger.info(f"Tracing is enabled. Trace events will be saved to {trace_events_path}")
+        logger.info(
+            f"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
+        )
+        trace_init(output_file=trace_events_path)
+
     logging.getLogger("httpx").setLevel(logging.WARNING)
     logging.getLogger("pylatexenc").setLevel(logging.WARNING)
