Skip to content

Commit

Permalink
Add quantiles parameter to do_bench_cudagraph (#4388)
Browse files Browse the repository at this point in the history
This makes its interface more similar to `do_bench`, making it easier to
switch between the two.
  • Loading branch information
int3 authored Jul 26, 2024
1 parent a51de76 commit 59a13eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
5 changes: 2 additions & 3 deletions python/triton/runtime/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,10 @@ def kernel_call():
if self.use_cuda_graph:
import torch
with torch.cuda.stream(torch.cuda.Stream()):
bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median")
return bench_res
return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
except (OutOfResources, CompileTimeAssertionFailure):
return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")]
return [float("inf"), float("inf"), float("inf")]

def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
Expand Down
22 changes: 13 additions & 9 deletions python/triton/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@ def nvsmi(attrs):
return ret


def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
def _summarize_statistics(times, quantiles, return_mode):
    """
    Reduce a tensor of timing samples to summary statistic(s).

    :param times: tensor of timing samples.
    :param quantiles: optional sequence of quantiles to compute. When it is not
        None it takes precedence and `return_mode` is ignored.
    :param return_mode: name of a `torch` reduction function (for example
        "mean" or "median") applied to `times` when `quantiles` is None.
    :return: with `quantiles`, a list of floats, or a single float when exactly
        one quantile was requested; otherwise the scalar produced by the
        `return_mode` reduction.
    """
    import torch

    if quantiles is None:
        return getattr(torch, return_mode)(times).item()

    # torch.quantile only accepts float32/float64 input and requires `q` to
    # have the same dtype (and device) as the input, so coerce unsupported
    # inputs to float and build `q` to match the samples.
    samples = times
    if samples.dtype not in (torch.float32, torch.float64):
        samples = samples.to(torch.float)
    q = torch.tensor(quantiles, dtype=samples.dtype, device=samples.device)

    ret = torch.quantile(samples, q).tolist()
    # A single requested quantile is returned as a bare scalar, not a list.
    if len(ret) == 1:
        ret = ret[0]
    return ret


def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
"""
Benchmark the runtime of the provided function.
Expand Down Expand Up @@ -77,8 +87,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"):
end_event.record()
torch.cuda.synchronize()
ret += [start_event.elapsed_time(end_event) / n_repeat]
times = torch.tensor(ret)
return getattr(torch, return_mode)(times).item()
return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
Expand Down Expand Up @@ -152,12 +161,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
return getattr(torch, return_mode)(times).item()
return _summarize_statistics(times, quantiles, return_mode)


def assert_close(x, y, atol=None, rtol=None, err_msg=''):
Expand Down

0 comments on commit 59a13eb

Please sign in to comment.