logger = logging.getLogger(__name__)


-def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1):
+def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1, cpu: bool = False):
    def trace_fn(
        p: torch.profiler.profile,
    ):
        averages = p.key_averages()
+        var_name = f"self_{'cpu' if cpu else 'cuda'}_time_total"
        if full_trace:
            # Show every GPU op.
            # Exclude CPU cuda ops to shorten the table.
            events = torch.autograd.profiler.EventList(
-                [evt for evt in p.profiler.function_events if evt.self_cuda_time_total > 0]
+                [evt for evt in p.profiler.function_events if getattr(evt, var_name) > 0]
            )
            log_rank_n(events.table(row_limit=-1, max_src_column_width=1000), logger.info, rank)

        if show_op_names:
            # Show non-cropped names, in the same order as in the table.
            averages_sorted = torch.autograd.profiler.EventList(
-                sorted(averages, key=lambda evt: evt.self_cuda_time_total, reverse=True)
+                sorted(averages, key=lambda evt: getattr(evt, var_name), reverse=True)
            )
            for entry in averages_sorted:
                log_rank_n(entry.key, logger.info, rank)

        # Try to avoid name cropping, still hard-coded to max 55 characters
-        log_rank_n(
-            averages.table(sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=1000), logger.info, rank
-        )
+        log_rank_n(averages.table(sort_by=var_name, row_limit=-1, max_src_column_width=1000), logger.info, rank)
+
+        # Store results for future use.
+        p.bc_profile_result = p.profiler.function_events

    return trace_fn

@@ -45,6 +47,7 @@ def get_profiler(
    cycles: int,
    full_trace: bool = False,
    show_op_names: bool = False,
+    cpu=False,
) -> Union[torch.profiler.profile, contextlib.nullcontext]:
    schedule = torch.profiler.schedule(
        # Warmup is a must if measuring speed as it's when all the optimizations are performed
@@ -57,6 +60,7 @@ def get_profiler(
    )
    return torch.profiler.profile(
        schedule=schedule,
-        activities=[torch.profiler.ProfilerActivity.CUDA],
-        on_trace_ready=get_trace_fn(full_trace, show_op_names),
+        activities=[torch.profiler.ProfilerActivity.CPU if cpu else torch.profiler.ProfilerActivity.CUDA],
+        on_trace_ready=get_trace_fn(full_trace, show_op_names, cpu=cpu),
+        with_modules=True,
    )
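
For reference, a minimal standalone sketch of the pattern this patch adopts: a single `cpu` switch picks both the profiler activity and the per-op self-time attribute used for filtering and sorting. The schedule values, the toy matmul workload, and the callback below are illustrative assumptions, not code from this repository.

import torch

cpu = True  # assumed switch: profile CPU ops here; False would profile CUDA kernels instead
var_name = f"self_{'cpu' if cpu else 'cuda'}_time_total"

def on_trace_ready(p: torch.profiler.profile) -> None:
    # Sort the aggregated table by the self time matching the profiled device.
    print(p.key_averages().table(sort_by=var_name, row_limit=20))

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=2, active=3),
    activities=[torch.profiler.ProfilerActivity.CPU if cpu else torch.profiler.ProfilerActivity.CUDA],
    on_trace_ready=on_trace_ready,
    with_modules=True,
) as p:
    for _ in range(6):
        torch.randn(512, 512) @ torch.randn(512, 512)  # toy workload, profiled on CPU here
        p.step()  # advance the torch.profiler schedule; the callback fires after the active window

The same attribute name drives both the event filter and the table sort, which is what lets `get_trace_fn` serve CPU and CUDA runs without duplicating the reporting logic.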