1- """The profiler and convert to torch utils """
1+ """Profiler and benchmarking utilities for PyTorch functions. """
22
3- import torch
3+ import os
4+ import sys
45from typing import Callable , List , Literal , Optional , Union
56
7+ import torch
8+
9+
10+ class suppress_stdout_stderr :
11+ """Context manager to suppress stdout and stderr output.
12+
13+ Source: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/testing/bench.py
14+ """
15+
16+ def __enter__ (self ):
17+ # Open null device files
18+ self .outnull_file = open (os .devnull , 'w' )
19+ self .errnull_file = open (os .devnull , 'w' )
20+
21+ # Save original file descriptors
22+ self .old_stdout_fileno_undup = sys .stdout .fileno ()
23+ self .old_stderr_fileno_undup = sys .stderr .fileno ()
24+ self .old_stdout_fileno = os .dup (sys .stdout .fileno ())
25+ self .old_stderr_fileno = os .dup (sys .stderr .fileno ())
26+
27+ # Save original stdout/stderr objects
28+ self .old_stdout = sys .stdout
29+ self .old_stderr = sys .stderr
30+
31+ # Redirect file descriptors and streams to null device
32+ os .dup2 (self .outnull_file .fileno (), self .old_stdout_fileno_undup )
33+ os .dup2 (self .errnull_file .fileno (), self .old_stderr_fileno_undup )
34+ sys .stdout = self .outnull_file
35+ sys .stderr = self .errnull_file
36+
37+ return self
38+
39+ def __exit__ (self , * _ ):
40+ # Restore original stdout/stderr objects
41+ sys .stdout = self .old_stdout
42+ sys .stderr = self .old_stderr
43+
44+ # Restore original file descriptors
45+ os .dup2 (self .old_stdout_fileno , self .old_stdout_fileno_undup )
46+ os .dup2 (self .old_stderr_fileno , self .old_stderr_fileno_undup )
47+
48+ # Close duplicated file descriptors
49+ os .close (self .old_stdout_fileno )
50+ os .close (self .old_stderr_fileno )
51+
52+ # Close null device files
53+ self .outnull_file .close ()
54+ self .errnull_file .close ()
55+
 
 def do_bench(
     fn: Callable,
     warmup: float = 25,
     rep: float = 100,
     _n_warmup: int = 0,
     _n_repeat: int = 0,
-    grad_to_none: Optional[List[torch.Tensor]] = None,
     quantiles: Optional[List[float]] = None,
     fast_flush: bool = True,
+    backend: Literal["event", "cupti"] = "event",
     return_mode: Literal["min", "max", "mean", "median"] = "mean",
 ) -> Union[float, List[float]]:
-    """Benchmarks the runtime of a PyTorch function.
+    """Benchmark the runtime of a PyTorch function with L2 cache management.
 
-    This function handles:
-    - L2 cache flushing between runs for consistent timing
-    - Automatic warmup and repeat count calculation
-    - Optional gradient clearing for backward passes
-    - Multiple measurement modes (mean, median, min, max)
+    This function provides accurate GPU kernel timing by:
+    - Clearing L2 cache between runs for consistent measurements
+    - Auto-calculating warmup and repeat counts based on kernel runtime
+    - Supporting multiple profiling backends (CUDA events or CUPTI)
+    - Offering flexible result aggregation (mean/median/min/max/quantiles)
 
     Args:
         fn: Function to benchmark
-        warmup: Target warmup time in milliseconds
-        rep: Target number of repetitions
-        _n_warmup: Override for number of warmup iterations
-        _n_repeat: Override for number of timing iterations
-        grad_to_none: Tensors whose gradients should be cleared between runs
-        quantiles: Optional performance percentiles to compute
-        fast_flush: Whether to use faster L2 cache flushing
-        return_mode: How to aggregate timing results ("mean", "median", "min", "max")
+        warmup: Target warmup time in milliseconds (default: 25)
+        rep: Target total benchmark time in milliseconds (default: 100)
+        _n_warmup: Manual override for warmup iterations (default: 0 = auto)
+        _n_repeat: Manual override for benchmark iterations (default: 0 = auto)
+        quantiles: Performance percentiles to compute (e.g., [0.5, 0.95])
+        fast_flush: Use faster L2 cache flush with int32 vs int8 (default: True)
+        backend: Profiler backend - "event" (CUDA events) or "cupti" (default: "event")
+        return_mode: Result aggregation method - "mean", "median", "min", or "max"
 
     Returns:
-        float: Aggregated runtime in milliseconds
+        Runtime in milliseconds (float) or list of quantile values if quantiles specified
     """
-    assert return_mode in ["min", "max", "mean", "median"]
+    assert return_mode in ["min", "max", "mean", "median"], \
+        f"Invalid return_mode: {return_mode}"
+
+    # Initial function call and synchronization
     fn()
     torch.cuda.synchronize()
 
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2
-    # doesn't contain any input data before the run
-    if fast_flush:
-        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
-    else:
-        cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
+    # Create L2 cache flush buffer (256 MB)
+    # Fast flush uses int32 (4 bytes), regular uses int8 (1 byte)
+    cache_size = int(256e6 // 4) if fast_flush else int(256e6)
+    cache_dtype = torch.int if fast_flush else torch.int8
+    cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda")
 
-    # Estimate the runtime of the function
+    # Estimate kernel runtime with 5 iterations
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
     start_event.record()
@@ -60,41 +111,87 @@ def do_bench(
     torch.cuda.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
 
-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
-    if _n_warmup > 0:
-        n_warmup = _n_warmup
-    if _n_repeat > 0:
-        n_repeat = _n_repeat
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    # Warm-up
+    # Calculate warmup and repeat counts (minimum 1 iteration each)
+    n_warmup = _n_warmup if _n_warmup > 0 else max(1, int(warmup / estimate_ms))
+    n_repeat = _n_repeat if _n_repeat > 0 else max(1, int(rep / estimate_ms))
+
+    # Warmup phase
     for _ in range(n_warmup):
         fn()
-    # Benchmark
+
+    # Benchmarking phase
+    if backend == "event":
+        return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode)
+    elif backend == "cupti":
+        return _bench_with_cupti(fn, cache, n_repeat)
+    else:
+        raise ValueError(f"Unknown profiler backend: {backend}")
+
+
+def _bench_with_cuda_events(
+    fn: Callable,
+    cache: torch.Tensor,
+    n_repeat: int,
+    quantiles: Optional[List[float]],
+    return_mode: str,
+) -> Union[float, List[float]]:
+    """Benchmark using CUDA events for timing."""
+    # Create timing events
+    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+
+    # Run benchmark iterations
     for i in range(n_repeat):
-        # we don't want `fn` to accumulate gradient values
-        # if it contains a backward pass. So we clear the
-        # provided gradients
-        if grad_to_none is not None:
-            for x in grad_to_none:
-                x.grad = None
-        # we clear the L2 cache before each run
-        cache.zero_()
-        # record time of `fn`
-        start_event[i].record()
+        cache.zero_()  # Clear L2 cache
+        start_events[i].record()
         fn()
-        end_event[i].record()
-    # Record clocks
+        end_events[i].record()
+
+    # Synchronize and collect timings
     torch.cuda.synchronize()
     times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
+        [s.elapsed_time(e) for s, e in zip(start_events, end_events)],
         dtype=torch.float,
     )
+
+    # Return quantiles if requested
     if quantiles is not None:
-        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
-        if len(ret) == 1:
-            ret = ret[0]
-        return ret
+        quantile_values = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        return quantile_values[0] if len(quantile_values) == 1 else quantile_values
+
+    # Return aggregated result
     return getattr(torch, return_mode)(times).item()
+
+
+def _bench_with_cupti(
+    fn: Callable,
+    cache: torch.Tensor,
+    n_repeat: int,
+) -> float:
+    """Benchmark using CUPTI profiler for detailed kernel timing."""
+    with suppress_stdout_stderr():
+        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
+        profiler = torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA],
+            schedule=schedule,
+        )
+
+        with profiler:
+            for _ in range(2):
+                for _ in range(n_repeat):
+                    cache.zero_()
+                    fn()
+                profiler.step()
+
+    # Calculate average kernel time, excluding cache-clearing overhead
+    total_cuda_time = 0.0
+    excluded_time = 0.0
+    excluded_kernels = "at::native::vectorized_elementwise"
+
+    for event in profiler.key_averages():
+        total_cuda_time += event.self_device_time_total
+        if excluded_kernels in event.key:
+            excluded_time += event.self_device_time_total
+
+    kernel_time_us = (total_cuda_time - excluded_time) / n_repeat
+    return kernel_time_us * 1e-3  # Convert microseconds to milliseconds
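
A minimal usage sketch of the new do_bench API (not part of the commit above); the import path, tensor shapes, and dtype are illustrative assumptions:

import torch
from profiler import do_bench  # hypothetical import path for the module in this diff

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Mean runtime in milliseconds using the default CUDA-event backend
ms = do_bench(lambda: a @ b, warmup=25, rep=100)

# Median and 95th-percentile latency (quantiles are only handled by the "event" backend)
p50, p95 = do_bench(lambda: a @ b, quantiles=[0.5, 0.95])

# CUPTI/Kineto-based timing, which subtracts the L2-cache-flush kernel from the total
ms_cupti = do_bench(lambda: a @ b, backend="cupti")

print(f"event: {ms:.3f} ms | p50: {p50:.3f} ms | p95: {p95:.3f} ms | cupti: {ms_cupti:.3f} ms")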