Commit 91d5ef5

[Profiler] Adds CUPTI profiler support (#936)
* [Profiler] Adds CUPTI profiler support
* format
* refactor cupti profiler
* format
* refactor
* refactor
* fix lint
* fix lint
* refactor
* add profiler tests

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent ac8c9af commit 91d5ef5

File tree

5 files changed: +221 −57 lines
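
The user-facing change is a new backend option on the profiler returned by kernel.get_profiler(). A minimal sketch of the call, mirroring the updated example below (kernel construction elided):

    profiler = kernel.get_profiler()
    latency_ms = profiler.do_bench(backend="cupti")   # new CUPTI-based timing path
    # latency_ms = profiler.do_bench()                # default CUDA-event timing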

examples/gemm/example_gemm.py
Lines changed: 6 additions & 0 deletions

@@ -51,6 +51,12 @@ def main():
    print("CUDA Source:")
    print(kernel.get_kernel_source())

+    # benchmark
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench(backend="cupti")
+    # latency = profiler.do_bench()
+    print(f"tilelang Latency: {latency}ms")
+

if __name__ == "__main__":
    main()
New file (profiler test)
Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+import tilelang
+import tilelang.language as T
+
+
+@tilelang.jit(out_idx=[-1])
+def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+
+    @T.prim_func
+    def gemm(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local)
+
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return gemm
+
+
+def test_profiler():
+    kernel = matmul(1024, 1024, 1024, 128, 128, 32)
+
+    import torch
+
+    a = torch.randn(1024, 1024).cuda().half()
+    b = torch.randn(1024, 1024).cuda().half()
+
+    c = kernel(a, b)
+    ref_c = a @ b
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+
+    # benchmark
+    profiler = kernel.get_profiler()
+
+    # use cupti backend
+    cupti_latency = profiler.do_bench(backend="cupti")
+
+    # use event backend
+    event_latency = profiler.do_bench(backend="event")
+    print(f"cupti Latency: {cupti_latency}ms")
+    print(f"event Latency: {event_latency}ms")
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()

tilelang/jit/adapter/cython/adapter.py
Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]:


class CythonKernelAdapter(BaseKernelAdapter):
-    """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
+    """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.

    This adapter handles:
    1. Converting TIR functions to compiled CUDA libraries

tilelang/profiler/__init__.py
Lines changed: 7 additions & 1 deletion

@@ -1,6 +1,6 @@
"""The profiler and convert to torch utils"""

-from typing import List, Optional, Callable, Any
+from typing import List, Optional, Callable, Any, Literal
from functools import partial
import torch
from contextlib import suppress
@@ -223,6 +223,9 @@ def do_bench(
        n_warmup: int = 1,
        n_repeat: int = 1,
        input_tensors: List[torch.Tensor] = None,
+        backend: Literal["event", "cupti"] = "event",
+        quantiles: Optional[List[float]] = None,
+        return_mode: Literal["min", "max", "mean", "median"] = "mean",
    ) -> float:
        """Benchmarks the execution time of a given function.
@@ -251,6 +254,9 @@ def do_bench(
                rep=rep,
                _n_warmup=n_warmup,
                _n_repeat=n_repeat,
+                quantiles=quantiles,
+                backend=backend,
+                return_mode=return_mode,
            )
        elif profiler == "tvm":
            assert func is not None, "func should not be None"
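
Taken together with bench.py below, a hedged sketch of how the extended signature can be used (quantiles and return_mode are forwarded to the event backend; the cupti backend returns an averaged kernel time and ignores them):

    profiler = kernel.get_profiler()

    # CUPTI backend: mean kernel time from torch.profiler, cache-clear kernels excluded
    cupti_ms = profiler.do_bench(backend="cupti")

    # Event backend: CUDA-event timing with configurable aggregation
    median_ms = profiler.do_bench(backend="event", return_mode="median")
    p50, p95 = profiler.do_bench(backend="event", quantiles=[0.5, 0.95])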

tilelang/profiler/bench.py
Lines changed: 152 additions & 55 deletions

@@ -1,55 +1,106 @@
-"""The profiler and convert to torch utils"""
+"""Profiler and benchmarking utilities for PyTorch functions."""

-import torch
+import os
+import sys
from typing import Callable, List, Literal, Optional, Union

+import torch
+
+
+class suppress_stdout_stderr:
+    """Context manager to suppress stdout and stderr output.
+
+    Source: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/testing/bench.py
+    """
+
+    def __enter__(self):
+        # Open null device files
+        self.outnull_file = open(os.devnull, 'w')
+        self.errnull_file = open(os.devnull, 'w')
+
+        # Save original file descriptors
+        self.old_stdout_fileno_undup = sys.stdout.fileno()
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+        # Save original stdout/stderr objects
+        self.old_stdout = sys.stdout
+        self.old_stderr = sys.stderr
+
+        # Redirect file descriptors and streams to null device
+        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+        sys.stdout = self.outnull_file
+        sys.stderr = self.errnull_file
+
+        return self
+
+    def __exit__(self, *_):
+        # Restore original stdout/stderr objects
+        sys.stdout = self.old_stdout
+        sys.stderr = self.old_stderr
+
+        # Restore original file descriptors
+        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+        # Close duplicated file descriptors
+        os.close(self.old_stdout_fileno)
+        os.close(self.old_stderr_fileno)
+
+        # Close null device files
+        self.outnull_file.close()
+        self.errnull_file.close()
+

def do_bench(
    fn: Callable,
    warmup: float = 25,
    rep: float = 100,
    _n_warmup: int = 0,
    _n_repeat: int = 0,
-    grad_to_none: Optional[List[torch.Tensor]] = None,
    quantiles: Optional[List[float]] = None,
    fast_flush: bool = True,
+    backend: Literal["event", "cupti"] = "event",
    return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
-    """Benchmarks the runtime of a PyTorch function.
+    """Benchmark the runtime of a PyTorch function with L2 cache management.

-    This function handles:
-    - L2 cache flushing between runs for consistent timing
-    - Automatic warmup and repeat count calculation
-    - Optional gradient clearing for backward passes
-    - Multiple measurement modes (mean, median, min, max)
+    This function provides accurate GPU kernel timing by:
+    - Clearing L2 cache between runs for consistent measurements
+    - Auto-calculating warmup and repeat counts based on kernel runtime
+    - Supporting multiple profiling backends (CUDA events or CUPTI)
+    - Offering flexible result aggregation (mean/median/min/max/quantiles)

    Args:
        fn: Function to benchmark
-        warmup: Target warmup time in milliseconds
-        rep: Target number of repetitions
-        _n_warmup: Override for number of warmup iterations
-        _n_repeat: Override for number of timing iterations
-        grad_to_none: Tensors whose gradients should be cleared between runs
-        quantiles: Optional performance percentiles to compute
-        fast_flush: Whether to use faster L2 cache flushing
-        return_mode: How to aggregate timing results ("mean", "median", "min", "max")
+        warmup: Target warmup time in milliseconds (default: 25)
+        rep: Target total benchmark time in milliseconds (default: 100)
+        _n_warmup: Manual override for warmup iterations (default: 0 = auto)
+        _n_repeat: Manual override for benchmark iterations (default: 0 = auto)
+        quantiles: Performance percentiles to compute (e.g., [0.5, 0.95])
+        fast_flush: Use faster L2 cache flush with int32 vs int8 (default: True)
+        backend: Profiler backend - "event" (CUDA events) or "cupti" (default: "event")
+        return_mode: Result aggregation method - "mean", "median", "min", or "max"

    Returns:
-        float: Aggregated runtime in milliseconds
+        Runtime in milliseconds (float) or list of quantile values if quantiles specified
    """
-    assert return_mode in ["min", "max", "mean", "median"]
+    assert return_mode in ["min", "max", "mean", "median"], \
+        f"Invalid return_mode: {return_mode}"
+
+    # Initial function call and synchronization
    fn()
    torch.cuda.synchronize()

-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2
-    # doesn't contain any input data before the run
-    if fast_flush:
-        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
-    else:
-        cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
+    # Create L2 cache flush buffer (256 MB)
+    # Fast flush uses int32 (4 bytes), regular uses int8 (1 byte)
+    cache_size = int(256e6 // 4) if fast_flush else int(256e6)
+    cache_dtype = torch.int if fast_flush else torch.int8
+    cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda")

-    # Estimate the runtime of the function
+    # Estimate kernel runtime with 5 iterations
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
@@ -60,41 +111,87 @@ def do_bench(
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

-    # compute number of warmup and repeat
-    n_warmup = max(1, int(warmup / estimate_ms))
-    n_repeat = max(1, int(rep / estimate_ms))
-    if _n_warmup > 0:
-        n_warmup = _n_warmup
-    if _n_repeat > 0:
-        n_repeat = _n_repeat
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    # Warm-up
+    # Calculate warmup and repeat counts (minimum 1 iteration each)
+    n_warmup = _n_warmup if _n_warmup > 0 else max(1, int(warmup / estimate_ms))
+    n_repeat = _n_repeat if _n_repeat > 0 else max(1, int(rep / estimate_ms))
+
+    # Warmup phase
    for _ in range(n_warmup):
        fn()
-    # Benchmark
+
+    # Benchmarking phase
+    if backend == "event":
+        return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode)
+    elif backend == "cupti":
+        return _bench_with_cupti(fn, cache, n_repeat)
+    else:
+        raise ValueError(f"Unknown profiler backend: {backend}")
+
+
+def _bench_with_cuda_events(
+    fn: Callable,
+    cache: torch.Tensor,
+    n_repeat: int,
+    quantiles: Optional[List[float]],
+    return_mode: str,
+) -> Union[float, List[float]]:
+    """Benchmark using CUDA events for timing."""
+    # Create timing events
+    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+
+    # Run benchmark iterations
    for i in range(n_repeat):
-        # we don't want `fn` to accumulate gradient values
-        # if it contains a backward pass. So we clear the
-        # provided gradients
-        if grad_to_none is not None:
-            for x in grad_to_none:
-                x.grad = None
-        # we clear the L2 cache before each run
-        cache.zero_()
-        # record time of `fn`
-        start_event[i].record()
+        cache.zero_()  # Clear L2 cache
+        start_events[i].record()
        fn()
-        end_event[i].record()
-    # Record clocks
+        end_events[i].record()
+
+    # Synchronize and collect timings
    torch.cuda.synchronize()
    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
+        [s.elapsed_time(e) for s, e in zip(start_events, end_events)],
        dtype=torch.float,
    )
+
+    # Return quantiles if requested
    if quantiles is not None:
-        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
-        if len(ret) == 1:
-            ret = ret[0]
-        return ret
+        quantile_values = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        return quantile_values[0] if len(quantile_values) == 1 else quantile_values
+
+    # Return aggregated result
    return getattr(torch, return_mode)(times).item()
+
+
+def _bench_with_cupti(
+    fn: Callable,
+    cache: torch.Tensor,
+    n_repeat: int,
+) -> float:
+    """Benchmark using CUPTI profiler for detailed kernel timing."""
+    with suppress_stdout_stderr():
+        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
+        profiler = torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA],
+            schedule=schedule,
+        )
+
+        with profiler:
+            for _ in range(2):
+                for _ in range(n_repeat):
+                    cache.zero_()
+                    fn()
+                profiler.step()
+
+    # Calculate average kernel time, excluding cache-clearing overhead
+    total_cuda_time = 0.0
+    excluded_time = 0.0
+    excluded_kernels = "at::native::vectorized_elementwise"
+
+    for event in profiler.key_averages():
+        total_cuda_time += event.self_device_time_total
+        if excluded_kernels in event.key:
+            excluded_time += event.self_device_time_total
+
+    kernel_time_us = (total_cuda_time - excluded_time) / n_repeat
+    return kernel_time_us * 1e-3  # Convert microseconds to milliseconds
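
The new helpers can also be exercised directly on a plain PyTorch callable. A minimal, self-contained sketch, assuming a CUDA device and that do_bench is imported from the tilelang.profiler.bench module shown above:

    import torch
    from tilelang.profiler.bench import do_bench

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

    def run():
        # the callable under test: a single half-precision matmul
        return a @ b

    event_ms = do_bench(run, backend="event")  # per-iteration CUDA-event timing
    cupti_ms = do_bench(run, backend="cupti")  # torch.profiler (CUPTI) kernel time
    print(f"event: {event_ms:.3f} ms, cupti: {cupti_ms:.3f} ms")

Because the cupti path subtracts the cache-zeroing elementwise kernels before averaging, the two numbers should be roughly comparable kernel-only times.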
