
Commit 9eef1ae

float8 profiling: save torch_logs, and attach them to profiling trace (#645)
Summary:

Adds two extra things to benchmarks/float8/profile_linear_float8.py:

1. for each profiling run, also saves the output of `TORCH_LOGS="output_code"`
2. for each profiling trace, creates a modified trace with the inductor code from (1) attached to triton kernel GPU events, which enables the user to click on the triton kernels in `chrome://tracing/` or perfetto and see the triton code inline in the browser - no need to cross-reference against logs anymore.
1 parent 261d0a4 commit 9eef1ae

File tree: 2 files changed (+233, -38 lines)

benchmarks/float8/profile_linear_float8.py

Lines changed: 103 additions & 38 deletions
@@ -6,15 +6,19 @@
 
 import copy
 import io
+import os
 import random
 from contextlib import nullcontext, redirect_stdout
 from dataclasses import dataclass, field
-from pathlib import Path
+import pathlib
 from typing import Callable, Optional
 
 import fire
 import pandas as pd
 
+# disable the inductor FX cache, so we can always study the inductor output logs
+os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'] = '1'
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -30,6 +34,7 @@
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
     profiler_output_to_filtered_time_by_kernel_name,
+    update_triton_kernels_in_prof_chome_trace_with_torch_logs,
 )
 
 # don't truncate long kernel names
@@ -38,6 +43,7 @@
 pd.set_option("display.float_format", "{:.3f}".format)
 
 
+
 class LNLinear(torch.nn.Module):
     def __init__(self, fc_dim1, fc_dim2):
         super().__init__()
@@ -151,7 +157,9 @@ def forward(self, h):
 
 @dataclass
 class ProfileConfig:
-    file_path: Optional[str] = None
+    trace_file_path: Optional[str] = None
+    logs_file_path: Optional[str] = None
+    trace_modified_file_path: Optional[str] = None
     name: Optional[str] = None
     cuda: bool = True
     iters: int = 0
@@ -162,13 +170,33 @@ class ProfileConfig:
 
 
 def profile_function(
-    config: ProfileConfig, func: Callable, *args, **kwargs
+    config: ProfileConfig,
+    func: Callable,
+    add_inductor_metadata_to_trace: bool,
+    *args,
+    **kwargs,
 ) -> torch.profiler.profile:
     """Profile a torch function and save the result to a file"""
     seed = 123
     random.seed(seed)
     torch.manual_seed(seed)
 
+    if add_inductor_metadata_to_trace:
+        # ensure we aren't interfering with other torch_log settings
+        if os.environ.get('TORCH_LOGS', '') != '':
+            raise AssertionError('using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet')
+
+        # save torch.compile logs to a file specific to this benchmark run
+        # TODO(future): can we hack torch.compile to print to file only and not stdout?
+        # or maybe just use tlparse?
+        torch._logging.set_logs(output_code=True)
+        # by default torch.compile appends to log_file_name, so we delete it
+        # if it exists
+        if os.path.isfile(config.logs_file_path):
+            pathlib.Path(config.logs_file_path).unlink()
+        torch._logging._init_logs(log_file_name=config.logs_file_path)
+
+
     activities = [ProfilerActivity.CPU]
     if config.cuda:
         activities.append(ProfilerActivity.CUDA)
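
For reference, the log-capture pattern above in isolation - a minimal sketch, assuming a recent PyTorch build where `torch._logging._init_logs` accepts `log_file_name` (a private API used by this commit, so it may change across releases); the file path is a placeholder:

    import torch

    # also write the inductor "output_code" logs to a file (per the TODO
    # above, torch.compile still prints them to stdout as well)
    torch._logging.set_logs(output_code=True)
    torch._logging._init_logs(log_file_name="/tmp/output_code.txt")  # placeholder path

    @torch.compile
    def f(x):
        return x * 2 + 1

    f(torch.randn(16, 16))

    # restore the default logging setup
    torch._logging.set_logs(output_code=False)
    torch._logging._init_logs(log_file_name=None)

After this runs, the log file holds the generated kernel source that the trace-modification step below parses.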
@@ -182,6 +210,10 @@ def profile_function(
         nullcontext() if config.name is None else record_function(config.name)
     )
     profile_memory = config.memory_profile_path is not None
+
+    # warm up
+    func(*args, **kwargs)
+
     with profile(
         activities=activities,
         profile_memory=profile_memory,
@@ -195,20 +227,35 @@ def profile_function(
             if config.sync:
                 torch.cuda.synchronize()
 
-    if config.file_path is not None:
-        prof.export_chrome_trace(config.file_path)
+    if config.trace_file_path is not None:
+        prof.export_chrome_trace(config.trace_file_path)
+
+        if add_inductor_metadata_to_trace:
+            # modify the trace to have the triton kernel metadata and code
+            # visible inline
+            update_triton_kernels_in_prof_chome_trace_with_torch_logs(
+                config.trace_file_path,
+                config.logs_file_path,
+                config.trace_modified_file_path,
+            )
+
+    # undo custom log settings
+    torch._logging.set_logs(output_code=False)
+    torch._logging._init_logs(log_file_name=None)
 
     return prof
 
 
 def main(
-    profile_path_prefix: Path,
+    profile_path_prefix: pathlib.Path,
     compile: bool = True,
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
+    add_inductor_metadata_to_trace: bool = True,
+    enable_sync_amax_history: bool = True,
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
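
The helper `update_triton_kernels_in_prof_chome_trace_with_torch_logs` lives in the second changed file, which is not shown on this page. A rough sketch of the approach the summary describes - hypothetical names, and an assumed log format - could look like:

    import json
    import re

    def attach_triton_code_to_trace(trace_in_path, logs_path, trace_out_path):
        # hypothetical stand-in for
        # update_triton_kernels_in_prof_chome_trace_with_torch_logs
        with open(trace_in_path) as f:
            trace = json.load(f)

        # map kernel name -> generated source, assuming the saved torch logs
        # contain blocks of the form "@triton.jit\ndef triton_<...>(...)"
        with open(logs_path) as f:
            log_text = f.read()
        kernel_name_to_code = {}
        for chunk in log_text.split("@triton.jit")[1:]:
            match = re.match(r"\s*def\s+(\w+)", chunk)
            if match:
                # the chunk runs until the next decorator, so it may carry
                # trailing non-kernel code; good enough for inline inspection
                kernel_name_to_code[match.group(1)] = "@triton.jit" + chunk

        # attach the source to matching kernel events; chrome://tracing/ and
        # perfetto display whatever is placed in an event's "args" dict
        for event in trace.get("traceEvents", []):
            code = kernel_name_to_code.get(event.get("name", ""))
            if code is not None:
                event.setdefault("args", {})["triton_code"] = code

        with open(trace_out_path, "w") as f:
            json.dump(trace, f)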
@@ -220,6 +267,8 @@ def main(
         cast_config_input=CastConfig(scaling_type=scaling_type_input),
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+        enable_amax_init=False,
+        enable_pre_and_post_forward=False,
     )
     scaling_repr = "_".join(
         [
@@ -290,7 +339,7 @@ def float8_forw_backward_wrapper(x):
        # inspection of the fw+bw torch.compile without the scale
        # syncing code
        # TODO(future): make this better
-        if linear_requires_sync(config):
+        if linear_requires_sync(config) and enable_sync_amax_history:
            with record_function("scale_amax_and_scales"):
                sync_amax_history(m_float8)
        out = float8_forw(x)
@@ -311,16 +360,14 @@ def float8_forw_backward_wrapper(x):
 
    # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
    # to populate triton kernel bandwidth further down in the script
-    f = io.StringIO()
+    if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
+        f = io.StringIO()
+        context = redirect_stdout(f)
+    else:
+        f = None
+        context = nullcontext()
    try:
-        with redirect_stdout(f):
-            # warm up
-            for _ in range(1):
-                if dtype_filter != "float8":
-                    ref_forw_backward(input_tensor)
-                if dtype_filter != "bfloat16":
-                    float8_forw_backward_wrapper(input_tensor)
-
+        with context:
            profile_iters = 5
            ref_times, float8_times = None, None
            data = []
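
The capture-or-passthrough pattern above, as a standalone stdlib-only sketch:

    import io
    import os
    from contextlib import nullcontext, redirect_stdout

    capture = os.environ.get("TORCHINDUCTOR_PROFILE", "") != ""
    f = io.StringIO() if capture else None
    context = redirect_stdout(f) if capture else nullcontext()

    with context:
        print("captured only when TORCHINDUCTOR_PROFILE is set")

    if f is not None:
        # replay the captured output to the real stdout
        print(f.getvalue())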
@@ -330,13 +377,19 @@ def float8_forw_backward_wrapper(x):
            if dtype_filter != "float8":
                # Profile Reference Model
                print("profiling ref")
-                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-                ref_path = profile_path_prefix + ref_suffix
+                ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_logs_suffix = f"_{model_type}_ref_compile_{compile}.txt"
+                trace_ref_path = profile_path_prefix + ref_trace_suffix
+                log_ref_path = profile_path_prefix + ref_logs_suffix
+                trace_ref_modified_path = trace_ref_path.replace(".json", "_modified.json")
                profile_config = ProfileConfig(
-                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+                    trace_ref_path, log_ref_path, trace_ref_modified_path, ref_trace_suffix, iters=profile_iters, warmup_iters=2, sync=True
                )
-                p = profile_function(profile_config, ref_forw_backward, input_tensor)
-                print(f"saved {ref_path}")
+                p = profile_function(profile_config, ref_forw_backward, add_inductor_metadata_to_trace, input_tensor)
+                print(f"saved profiling trace to {trace_ref_path}")
+                if add_inductor_metadata_to_trace:
+                    print(f"saved torch logs to {log_ref_path}")
+                    print(f"saved modified trace to {trace_ref_modified_path}")
                ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
                for k, v in ref_times.items():
@@ -355,21 +408,31 @@ def float8_forw_backward_wrapper(x):
            if dtype_filter != "bfloat16":
                # Profile Float8 Model
                print("profiling float8")
-                float8_suffix = (
+                float8_trace_suffix = (
                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
                )
-                float8_path = profile_path_prefix + float8_suffix
+                float8_log_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
+                )
+                trace_float8_path = profile_path_prefix + float8_trace_suffix
+                log_float8_path = profile_path_prefix + float8_log_suffix
+                trace_float8_modified_path = trace_float8_path.replace(".json", "_modified.json")
                profile_config = ProfileConfig(
-                    float8_path,
-                    float8_suffix,
+                    trace_float8_path,
+                    log_float8_path,
+                    trace_float8_modified_path,
+                    float8_trace_suffix,
                    iters=profile_iters,
                    warmup_iters=2,
                    sync=True,
                )
                p = profile_function(
-                    profile_config, float8_forw_backward_wrapper, input_tensor
+                    profile_config, float8_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor
                )
-                print(f"saved {float8_path}")
+                print(f"saved profiling trace to {trace_float8_path}")
+                if add_inductor_metadata_to_trace:
+                    print(f"saved torch logs to {log_float8_path}")
+                    print(f"saved modified trace to {trace_float8_modified_path}")
                float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
                total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
                for k, v in float8_times.items():
@@ -393,17 +456,19 @@ def float8_forw_backward_wrapper(x):
                print(f"Sync time ms: {sync_time_ms}")
 
    finally:
-        # print the redirected stdout back to regular stdout
-        print(f.getvalue())
-
-        # populate the triton kernel bandwidth
-        for line in f.getvalue().split("\n"):
-            maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
-            if maybe_kernel_name is not None:
-                # O(N) search, but it's ok since lists are small
-                for datum in data:
-                    if datum[1] == maybe_kernel_name:
-                        datum[-1] = maybe_bw
+        if f is not None:
+            # print the redirected stdout back to regular stdout
+            print(f.getvalue())
+
+        if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
+            # populate the triton kernel bandwidth
+            for line in f.getvalue().split("\n"):
+                maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
+                if maybe_kernel_name is not None:
+                    # O(N) search, but it's ok since lists are small
+                    for datum in data:
+                        if datum[1] == maybe_kernel_name:
+                            datum[-1] = maybe_bw
 
    df = pd.DataFrame(
        data,
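
A hypothetical invocation, assuming the script dispatches `main` through `fire.Fire` (suggested by the `import fire` above), with `./prof/run1` as a placeholder path prefix:

    python benchmarks/float8/profile_linear_float8.py ./prof/run1 --model_type linear

With the default `add_inductor_metadata_to_trace=True`, each profiled dtype then produces three artifacts following the suffix logic above: a `.json` chrome trace, a `.txt` file holding the `TORCH_LOGS="output_code"` output, and a `_modified.json` trace to open in `chrome://tracing/` or perfetto.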
