
Support displaying separate send and recv time #239

Status: Open · wants to merge 6 commits into main
tests/test_low_latency.py (10 changes: 6 additions & 4 deletions)
@@ -143,15 +143,17 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
     # Separate profiling
     for return_recv_hook in (False, True):
         group.barrier()
-        dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
+        bench_output = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
                                     kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
-                                    suppress_kineto_output=True)
+                                    suppress_kineto_output=True, duplicate_name_period=2 if return_recv_hook else None)
         if not return_recv_hook:
+            dispatch_t, combine_t = bench_output
             print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                   f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
         else:
-            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
-                  f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us', flush=True)
+            dispatch_t, combine_t, detail_times = bench_output
+            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} = {detail_times["dispatch"][0] * 1e6:.2f} + {detail_times["dispatch"][1] * 1e6:.2f} us | '
+                  f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} = {detail_times["combine"][0] * 1e6:.2f} + {detail_times["combine"][1] * 1e6:.2f} us', flush=True)
 
     return hash_value
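With return_recv_hook=True, each dispatch/combine kernel runs twice per iteration (a send phase and a recv-hook phase), so the averaged kernel time covers a single phase and the existing `* 2` recovers the full round; the new print decomposes that total into its two phases via detail_times. An illustrative output line (the numbers are invented; the format follows the f-string above, and the two parts sum to the total by construction):

    [rank 0] Dispatch send/recv time: 20.00 = 8.00 + 12.00 us | Combine send/recv time: 30.00 = 13.00 + 17.00 us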
tests/utils.py (31 changes: 29 additions & 2 deletions)
@@ -1,4 +1,8 @@
 import inspect
+import json
+import tempfile
+from pathlib import Path
+
 import numpy as np
 import os
 import sys
@@ -152,7 +156,8 @@ def __exit__(self, *_):
 
 
 def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
-                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
+                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
+                 duplicate_name_period: Optional[int] = None):
     # Profile
     suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
     with suppress():
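For callers, the new parameter is opt-in: with duplicate_name_period=None the return value is unchanged, while a period of n appends a dict mapping each kernel name to its n per-phase mean times. A minimal usage sketch (hypothetical caller; fn stands for any profiled callable that launches each kernel twice per iteration, send phase then recv phase):

    # Sketch only: bench_kineto and its parameters are from this PR;
    # fn is a placeholder for the benchmarked callable.
    dispatch_t, combine_t, detail_times = bench_kineto(
        fn, kernel_names=('dispatch', 'combine'),
        suppress_kineto_output=True, duplicate_name_period=2)
    send_t, recv_t = detail_times['dispatch']  # mean seconds for each phase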
@@ -194,8 +199,30 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
                         kernel_times.append(float(time_str.replace(unit, '')) / scale)
                         break
                 break
-    return tuple(kernel_times) if is_tupled else kernel_times[0]
+
+    if duplicate_name_period is None:
+        return tuple(kernel_times) if is_tupled else kernel_times[0]
+    else:
+        detail_times = extract_detail_times_from_prof(prof, kernel_names=kernel_names, duplicate_name_period=duplicate_name_period)
+        return tuple(kernel_times) + (detail_times,)
 
 
+def extract_detail_times_from_prof(prof, kernel_names, duplicate_name_period: int):
+    with tempfile.NamedTemporaryFile(suffix=".json") as tmp:
+        prof.export_chrome_trace(tmp.name)
+        profile_data = json.loads(Path(tmp.name).read_text())
+
+    ans = {}
+    for kernel_name in kernel_names:
+        name_matcher = f'::{kernel_name}<'
+        events = [e for e in profile_data["traceEvents"] if name_matcher in e["name"]]
+        events = sorted(events, key=lambda e: e["ts"])
+        durations = [e["dur"] / 1e6 for e in events]
+        ans[kernel_name] = [list_mean(durations[i::duplicate_name_period]) for i in range(duplicate_name_period)]
+    return ans
+
+def list_mean(xs):
+    return sum(xs) / len(xs)
+
 def hash_tensor(t: torch.Tensor):
     return t.view(torch.int64).sum().item()
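The heart of extract_detail_times_from_prof is the strided slice durations[i::duplicate_name_period]: it assumes same-named kernel occurrences alternate through the trace with a fixed period, so with a period of 2 the even-indexed occurrences form one phase (send, since events are sorted by timestamp) and the odd-indexed ones the other (recv). A self-contained toy sketch of that grouping, with invented numbers:

    # Toy illustration (made-up durations) of the phase-splitting slice above.
    durations = [7.0, 9.0, 8.0, 10.0, 7.5, 9.5]  # send, recv, send, recv, ...
    period = 2
    # A stride slice starting at offset i collects that phase's occurrences.
    phase_means = [sum(durations[i::period]) / len(durations[i::period])
                   for i in range(period)]
    print(phase_means)  # [7.5, 9.5] -> mean send time, mean recv time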