 
 import copy
 import io
+import os
 import random
 from contextlib import nullcontext, redirect_stdout
 from dataclasses import dataclass, field
-from pathlib import Path
+import pathlib
 from typing import Callable, Optional
 
 import fire
 import pandas as pd
 
+# disable inductor FX cache, so we can always study the inductor output logs
+os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'] = '1'
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ ... @@
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
     profiler_output_to_filtered_time_by_kernel_name,
+    update_triton_kernels_in_prof_chome_trace_with_torch_logs,
 )
 
 # don't truncate long kernel names
@@ ... @@
 pd.set_option("display.float_format", "{:.3f}".format)
 
 
+
 class LNLinear(torch.nn.Module):
     def __init__(self, fc_dim1, fc_dim2):
         super().__init__()
@@ -151,7 +157,9 @@ def forward(self, h):
 
 @dataclass
 class ProfileConfig:
-    file_path: Optional[str] = None
+    trace_file_path: Optional[str] = None
+    logs_file_path: Optional[str] = None
+    trace_modified_file_path: Optional[str] = None
     name: Optional[str] = None
     cuda: bool = True
     iters: int = 0
@@ -162,13 +170,33 @@ class ProfileConfig:
 
 
 def profile_function(
-    config: ProfileConfig, func: Callable, *args, **kwargs
+    config: ProfileConfig,
+    func: Callable,
+    add_inductor_metadata_to_trace: bool,
+    *args,
+    **kwargs,
 ) -> torch.profiler.profile:
     """Profile a torch function and save the result to a file"""
     seed = 123
     random.seed(seed)
     torch.manual_seed(seed)
 
+    if add_inductor_metadata_to_trace:
+        # ensure we aren't interfering with other torch_log settings
+        if os.environ.get('TORCH_LOGS', '') != '':
+            raise AssertionError('using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet')
+
+        # save torch.compile logs to a file specific to this benchmark run
+        # TODO(future): can we hack torch.compile to print to file only and not stdout?
+        # or maybe just use tlparse?
+        torch._logging.set_logs(output_code=True)
+        # by default torch.compile appends to log_file_name, so we delete it
+        # if it exists
+        if os.path.isfile(config.logs_file_path):
+            pathlib.Path.unlink(config.logs_file_path)
+        torch._logging._init_logs(log_file_name=config.logs_file_path)
+
+
     activities = [ProfilerActivity.CPU]
     if config.cuda:
         activities.append(ProfilerActivity.CUDA)
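For context, the inductor-log capture that the hunk above wires into profile_function can be exercised on its own roughly as below. This is a minimal sketch, not part of the diff: the log path and the toy compiled function are placeholders, and torch._logging._init_logs is the same private API the patch itself calls.

    import os
    import pathlib

    import torch

    logs_file_path = "/tmp/inductor_output_code.txt"  # placeholder path, not from the PR

    # route inductor's generated code into a fresh log file
    if os.path.isfile(logs_file_path):
        pathlib.Path(logs_file_path).unlink()
    torch._logging.set_logs(output_code=True)
    torch._logging._init_logs(log_file_name=logs_file_path)

    # any torch.compile'd call now appends its generated kernel code to the log
    fn = torch.compile(lambda x: torch.nn.functional.relu(x @ x))
    fn(torch.randn(8, 8))

    # restore default logging, mirroring the cleanup later in profile_function
    torch._logging.set_logs(output_code=False)
    torch._logging._init_logs(log_file_name=None)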
@@ -182,6 +210,10 @@ def profile_function(
         nullcontext() if config.name is None else record_function(config.name)
     )
     profile_memory = config.memory_profile_path is not None
+
+    # warm up
+    func(*args, **kwargs)
+
     with profile(
         activities=activities,
         profile_memory=profile_memory,
@@ -195,20 +227,35 @@ def profile_function(
                 if config.sync:
                     torch.cuda.synchronize()
 
-    if config.file_path is not None:
-        prof.export_chrome_trace(config.file_path)
+    if config.trace_file_path is not None:
+        prof.export_chrome_trace(config.trace_file_path)
+
+    if add_inductor_metadata_to_trace:
+        # modify the trace to have the triton kernel metadata and code
+        # visible inline
+        update_triton_kernels_in_prof_chome_trace_with_torch_logs(
+            config.trace_file_path,
+            config.logs_file_path,
+            config.trace_modified_file_path,
+        )
+
+        # undo custom log settings
+        torch._logging.set_logs(output_code=False)
+        torch._logging._init_logs(log_file_name=None)
 
     return prof
 
 
 def main(
-    profile_path_prefix: Path,
+    profile_path_prefix: pathlib.Path,
     compile: bool = True,
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
+    add_inductor_metadata_to_trace: bool = True,
+    enable_sync_amax_history: bool = True,
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
@@ -220,6 +267,8 @@ def main(
         cast_config_input=CastConfig(scaling_type=scaling_type_input),
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+        enable_amax_init=False,
+        enable_pre_and_post_forward=False,
     )
     scaling_repr = "_".join(
         [
@@ -290,7 +339,7 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(config):
+        if linear_requires_sync(config) and enable_sync_amax_history:
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
@@ -311,16 +360,14 @@ def float8_forw_backward_wrapper(x):
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
-    f = io.StringIO()
+    if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
+        context = nullcontext()
+        f = None
+    else:
+        f = io.StringIO()
+        context = redirect_stdout(f)
     try:
-        with redirect_stdout(f):
-            # warm up
-            for _ in range(1):
-                if dtype_filter != "float8":
-                    ref_forw_backward(input_tensor)
-                if dtype_filter != "bfloat16":
-                    float8_forw_backward_wrapper(input_tensor)
-
+        with context:
             profile_iters = 5
             ref_times, float8_times = None, None
             data = []
330
377
if dtype_filter != "float8" :
331
378
# Profile Reference Model
332
379
print ("profiling ref" )
333
- ref_suffix = f"_{ model_type } _ref_compile_{ compile } .json"
334
- ref_path = profile_path_prefix + ref_suffix
380
+ ref_trace_suffix = f"_{ model_type } _ref_compile_{ compile } .json"
381
+ ref_logs_suffix = f"_{ model_type } _ref_compile_{ compile } .txt"
382
+ trace_ref_path = profile_path_prefix + ref_trace_suffix
383
+ log_ref_path = profile_path_prefix + ref_logs_suffix
384
+ trace_ref_modified_path = trace_ref_path .replace (".json" , "_modified.json" )
335
385
profile_config = ProfileConfig (
336
- ref_path , ref_suffix , iters = profile_iters , warmup_iters = 2 , sync = True
386
+ trace_ref_path , log_ref_path , trace_ref_modified_path , ref_trace_suffix , iters = profile_iters , warmup_iters = 2 , sync = True
337
387
)
338
- p = profile_function (profile_config , ref_forw_backward , input_tensor )
339
- print (f"saved { ref_path } " )
388
+ p = profile_function (profile_config , ref_forw_backward , add_inductor_metadata_to_trace , input_tensor )
389
+ print (f"saved profiling trace to { trace_ref_path } " )
390
+ if add_inductor_metadata_to_trace :
391
+ print (f"saved torch logs to { log_ref_path } " )
392
+ print (f"saved modified trace to { trace_ref_modified_path } " )
340
393
ref_times = profiler_output_to_filtered_time_by_kernel_name (p , profile_iters , num_leaf_tensors )
341
394
total_time_ms = sum (v for v in ref_times .values ()) / 1e3 / profile_iters
342
395
for k , v in ref_times .items ():
@@ -355,21 +408,31 @@ def float8_forw_backward_wrapper(x):
355
408
if dtype_filter != "bfloat16" :
356
409
# Profile Float8 Model
357
410
print ("profiling float8" )
358
- float8_suffix = (
411
+ float8_trace_suffix = (
359
412
f"_{ model_type } _float8_compile_{ compile } _{ scaling_repr } .json"
360
413
)
361
- float8_path = profile_path_prefix + float8_suffix
414
+ float8_log_suffix = (
415
+ f"_{ model_type } _float8_compile_{ compile } _{ scaling_repr } .txt"
416
+ )
417
+ trace_float8_path = profile_path_prefix + float8_trace_suffix
418
+ log_float8_path = profile_path_prefix + float8_log_suffix
419
+ trace_float8_modified_path = trace_float8_path .replace (".json" , "_modified.json" )
362
420
profile_config = ProfileConfig (
363
- float8_path ,
364
- float8_suffix ,
421
+ trace_float8_path ,
422
+ log_float8_path ,
423
+ trace_float8_modified_path ,
424
+ float8_trace_suffix ,
365
425
iters = profile_iters ,
366
426
warmup_iters = 2 ,
367
427
sync = True ,
368
428
)
369
429
p = profile_function (
370
- profile_config , float8_forw_backward_wrapper , input_tensor
430
+ profile_config , float8_forw_backward_wrapper , add_inductor_metadata_to_trace , input_tensor
371
431
)
372
- print (f"saved { float8_path } " )
432
+ print (f"saved profiling trace to { trace_float8_path } " )
433
+ if add_inductor_metadata_to_trace :
434
+ print (f"saved torch logs to { log_float8_path } " )
435
+ print (f"saved modified trace to { trace_float8_modified_path } " )
373
436
float8_times = profiler_output_to_filtered_time_by_kernel_name (p , profile_iters , num_leaf_tensors )
374
437
total_time_ms = sum (v for v in float8_times .values ()) / 1e3 / profile_iters
375
438
for k , v in float8_times .items ():
@@ -393,17 +456,19 @@ def float8_forw_backward_wrapper(x):
393
456
print (f"Sync time ms: { sync_time_ms } " )
394
457
395
458
finally :
396
- # print the redirected stdout back to regular stdout
397
- print (f .getvalue ())
398
-
399
- # populate the triton kernel bandwidth
400
- for line in f .getvalue ().split ("\n " ):
401
- maybe_bw , maybe_kernel_name = parse_bw_and_kernel_name (line )
402
- if maybe_kernel_name is not None :
403
- # O(N) search, but it's ok since lists are small
404
- for datum in data :
405
- if datum [1 ] == maybe_kernel_name :
406
- datum [- 1 ] = maybe_bw
459
+ if f is not None :
460
+ # print the redirected stdout back to regular stdout
461
+ print (f .getvalue ())
462
+
463
+ if os .environ .get ("TORCHINDUCTOR_PROFILE" , "" ) != "" :
464
+ # populate the triton kernel bandwidth
465
+ for line in f .getvalue ().split ("\n " ):
466
+ maybe_bw , maybe_kernel_name = parse_bw_and_kernel_name (line )
467
+ if maybe_kernel_name is not None :
468
+ # O(N) search, but it's ok since lists are small
469
+ for datum in data :
470
+ if datum [1 ] == maybe_kernel_name :
471
+ datum [- 1 ] = maybe_bw
407
472
408
473
df = pd .DataFrame (
409
474
data ,
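Taken together, the new calling convention for the profiler helper looks roughly like this. It is an illustrative sketch mirroring how main() uses the API after this patch, not code from the PR itself; the file names and input tensor are placeholders, and ProfileConfig's positional arguments follow the field order added above (trace_file_path, logs_file_path, trace_modified_file_path, name).

    profile_config = ProfileConfig(
        "out_float8.json",           # trace_file_path: chrome trace exported by the profiler
        "out_float8.txt",            # logs_file_path: captured inductor output_code logs
        "out_float8_modified.json",  # trace_modified_file_path: trace with kernel code inlined
        "float8_profile",            # name, wrapped in record_function
        iters=5,
        warmup_iters=2,
        sync=True,
    )
    # the new third positional argument controls whether torch logs are captured
    # and merged back into the exported chrome trace
    p = profile_function(profile_config, float8_forw_backward_wrapper, True, input_tensor)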