
Commit fd75d7c: Cpu profile

1 parent 2ad418c commit fd75d7c

File tree: 3 files changed (+22 lines, -14 lines)


src/main.py

Lines changed: 7 additions & 3 deletions

@@ -58,6 +58,7 @@ def get_arg_parser() -> ArgumentParser:
     parser.add_argument("--max_log_outputs", type=int)
     parser.add_argument("--breakdown_latency", "--bl", action="store_true")
     parser.add_argument("--profile", "-p", action="store_true")
+    parser.add_argument("--profile_cpu", "--pcpu", action="store_true")
     parser.add_argument("--profile_cycles", "--pc", type=int)
     parser.add_argument("--full_trace", "--pt", action="store_true")
     parser.add_argument("--show_op_names", "--pn", action="store_true")
@@ -108,13 +109,16 @@ def main(argv: Optional[List[str]] = None) -> None:
 
     all_metrics = []
 
-    if args.profile:
+    profile = args.profile or args.profile_cpu
+
+    if profile:
         profiler = get_profiler(
             skip=args.skip + pre_warmup_cycles,
             warmup=warmup,
             cycles=post_warmup_cycles,
             full_trace=args.full_trace,
             show_op_names=args.show_op_names,
+            cpu=args.profile_cpu,
         )
     else:
         profiler = contextlib.nullcontext()
@@ -125,7 +129,7 @@ def main(argv: Optional[List[str]] = None) -> None:
         "Cycles (warmup)": args.skip + warmup,
         "Cycles (benchmark)": args.cycles,
     }
-    if args.profile:
+    if profile:
         benchmark_metrics["Cycles (profile)"] = post_warmup_cycles
         benchmark_metrics["Cycles (total)"] = args.skip + warmup + pre_warmup_cycles + post_warmup_cycles
 
@@ -158,7 +162,7 @@ def main(argv: Optional[List[str]] = None) -> None:
             ignore_oom=args.ignore_oom,
             pad_generated_tokens=args.pad_generated_tokens,
         )
-        if args.profile:
+        if profile:
             p.step()
 
         if step == 0:
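With this change, --profile_cpu implies profiling on its own, so users do not have to pass --profile as well. A minimal sketch of the same flag wiring (hypothetical standalone snippet, not part of this repo):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--profile", "-p", action="store_true")
parser.add_argument("--profile_cpu", "--pcpu", action="store_true")

# Either flag enables the profiler; --profile_cpu additionally selects CPU mode.
args = parser.parse_args(["--pcpu"])
profile = args.profile or args.profile_cpu
assert profile and args.profile_cpu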

src/pipeline.py

Lines changed: 3 additions & 3 deletions

@@ -619,8 +619,8 @@ def _generate_textgen(
         with torch.inference_mode():
             for key_length in range(input_length, output_length, key_length_step):
                 try:
-                    if (key_length_step > 1 and key_length>key_length) or not use_cache or not do_prefill:
-                        if not hasattr(self.model,"fast_forward"):
+                    if (key_length_step > 1 and key_length > key_length) or not use_cache or not do_prefill:
+                        if not hasattr(self.model, "fast_forward"):
                             raise NotImplementedError()
                         self.model.fast_forward(batch, key_length, use_cache)
                     last_time = self._get_time(breakdown_latency)
@@ -718,7 +718,7 @@ def __call__(
             Metrics.LATENCY_E2E: t1 - t0,
         }
 
-        output_text=[i+o for i, o in zip(text, output_text)]
+        output_text = [i + o for i, o in zip(text, output_text)]
 
         return output_text, metrics
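The reformatted line above prepends each input prompt to its generated continuation, batch element by batch element. For illustration (hypothetical values):

# Pairwise concatenation: prompt + generated text for each batch element.
text = ["Hello", "The cat"]
output_text = [" world!", " sat down."]
output_text = [i + o for i, o in zip(text, output_text)]
assert output_text == ["Hello world!", "The cat sat down."]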

src/profile.py

Lines changed: 12 additions & 8 deletions

@@ -10,31 +10,33 @@
 logger = logging.getLogger(__name__)
 
 
-def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1):
+def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1, cpu: bool = False):
     def trace_fn(
         p: torch.profiler.profile,
     ):
         averages = p.key_averages()
+        var_name = f"self_{'cpu' if cpu else 'cuda'}_time_total"
         if full_trace:
             # Show every GPU op.
             # Exclude CPU cuda ops to shorten the table.
             events = torch.autograd.profiler.EventList(
-                [evt for evt in p.profiler.function_events if evt.self_cuda_time_total > 0]
+                [evt for evt in p.profiler.function_events if getattr(evt, var_name) > 0]
             )
             log_rank_n(events.table(row_limit=-1, max_src_column_width=1000), logger.info, rank)
 
         if show_op_names:
             # Show non-cropped names, in the same order as in the table.
             averages_sorted = torch.autograd.profiler.EventList(
-                sorted(averages, key=lambda evt: evt.self_cuda_time_total, reverse=True)
+                sorted(averages, key=lambda evt: getattr(evt, var_name), reverse=True)
             )
             for entry in averages_sorted:
                 log_rank_n(entry.key, logger.info, rank)
 
         # Try to avoid name cropping, still hard-coded to max 55 characters
-        log_rank_n(
-            averages.table(sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=1000), logger.info, rank
-        )
+        log_rank_n(averages.table(sort_by=var_name, row_limit=-1, max_src_column_width=1000), logger.info, rank)
+
+        # Store results for future use.
+        p.bc_profile_result = p.profiler.function_events
 
     return trace_fn
 
@@ -45,6 +47,7 @@ def get_profiler(
     cycles: int,
     full_trace: bool = False,
     show_op_names: bool = False,
+    cpu=False,
 ) -> Union[torch.profiler.profile, contextlib.nullcontext]:
     schedule = torch.profiler.schedule(
         # Warmup is a must if measuring speed as it's when all the optimizations are performed
@@ -57,6 +60,7 @@ def get_profiler(
     )
     return torch.profiler.profile(
         schedule=schedule,
-        activities=[torch.profiler.ProfilerActivity.CUDA],
-        on_trace_ready=get_trace_fn(full_trace, show_op_names),
+        activities=[torch.profiler.ProfilerActivity.CPU if cpu else torch.profiler.ProfilerActivity.CUDA],
+        on_trace_ready=get_trace_fn(full_trace, show_op_names, cpu=cpu),
+        with_modules=True,
     )
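Taken together, these changes pick the profiled activity (CPU or CUDA) and sort the result tables by the matching self_cpu_time_total / self_cuda_time_total attribute. A minimal self-contained sketch of the same torch.profiler usage in CPU mode (placeholder workload and cycle counts, assuming only PyTorch is installed):

import torch

def trace_fn(p: torch.profiler.profile):
    # Sort the summary table by CPU time, mirroring var_name above.
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_fn,
    with_modules=True,
) as p:
    for _ in range(4):  # wait + warmup + active steps
        torch.randn(256, 256) @ torch.randn(256, 256)  # placeholder workload
        p.step()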
