
Operator level microbenchmarking #3154

Open · wants to merge 1 commit into base: main
27 changes: 24 additions & 3 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -40,7 +40,11 @@
TestTowerCollectionSparseNNConfig,
TestTowerSparseNNConfig,
)
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
from torchrec.distributed.benchmark.benchmark_utils import (
benchmark_func,
benchmark_operators,
cmd_conf,
)
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import Topology
@@ -110,6 +114,9 @@ class RunOptions:
sparse_lr: float = 0.1
sparse_momentum: Optional[float] = None
sparse_weight_decay: Optional[float] = None
benchmark_operators: bool = False
target_operators: Optional[List[str]] = None
limit_operator_results: int = 10


@dataclass
@@ -379,10 +386,11 @@ def _func_to_benchmark(
if jit_suffix
else type(pipeline).__name__
)

result = benchmark_func(
name=name,
bench_inputs=bench_inputs, # pyre-ignore
prof_inputs=bench_inputs, # pyre-ignore
bench_inputs=bench_inputs, # pyre-ignore[6]
prof_inputs=bench_inputs, # pyre-ignore[6]
num_benchmarks=5,
num_profiles=2,
profile_dir=run_option.profile,
@@ -393,6 +401,19 @@
)
results.append(result)

if run_option.benchmark_operators:
op_results = benchmark_operators(
func_to_benchmark=pipeline,
bench_inputs=bench_inputs,
num_benchmarks=5,
device_type="cuda",
target_operators=run_option.target_operators,
is_pipeline=True,
rank=rank,
limit_results=run_option.limit_operator_results,
)
results.extend(op_results)

if rank == 0:
for result in results:
print(result)
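
A minimal sketch (not part of the diff) of how the three new RunOptions fields above might be set to enable operator-level benchmarking. Constructing RunOptions directly is an illustration only; actual runs would presumably supply these values through the existing cmd_conf entry point, the operator names in target_operators are hypothetical, and any RunOptions fields not shown are assumed to keep their defaults.

from torchrec.distributed.benchmark.benchmark_train_pipeline import RunOptions

run_option = RunOptions(
    benchmark_operators=True,  # also run the per-operator profiling pass after benchmark_func
    target_operators=["aten::mm", "aten::addmm"],  # hypothetical filter; None keeps every profiled operator
    limit_operator_results=5,  # report only the 5 slowest operators (default is 10)
)
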
77 changes: 77 additions & 0 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -905,6 +905,83 @@ def trace_handler(prof) -> None:
)


def benchmark_operators(
func_to_benchmark: Any, # pyre-ignore[2]
bench_inputs: List[Any], # pyre-ignore[2]
num_benchmarks: int,
device_type: str = "cuda",
target_operators: Optional[List[str]] = None,
is_pipeline: bool = False,
rank: int = -1,
limit_results: int = 10,
) -> List[BenchmarkResult]:
activities = [torch.profiler.ProfilerActivity.CPU]
if device_type == "cuda":
activities.append(torch.profiler.ProfilerActivity.CUDA)

results = []
elapsed_times = {}
peak_memory_usage = {}

for _ in range(num_benchmarks):
with torch.profiler.profile(
activities=activities,
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
with_modules=True,
) as prof:
if is_pipeline:
dataloader = iter(bench_inputs)
while True:
try:
func_to_benchmark.progress(dataloader)
except StopIteration:
break
else:
for bench_input in bench_inputs:
func_to_benchmark(bench_input)

for evt in prof.key_averages():
if evt.key not in elapsed_times:
elapsed_times[evt.key] = []
peak_memory_usage[evt.key] = 0

elapsed_times[evt.key].append(evt.self_device_time_total / 1e3)
peak_memory_usage[evt.key] = max(
peak_memory_usage[evt.key], evt.self_device_memory_usage
)

for op in elapsed_times:
if target_operators is not None and op not in target_operators:
continue

mem_stats = [
MemoryStats(
rank=rank,
malloc_retries=-1, # Not supported in profiler
max_mem_allocated_mbs=peak_memory_usage[op] / 1024 / 1024,
max_mem_reserved_mbs=-1, # Not supported in profiler
)
]

results.append(
BenchmarkResult(
short_name=f"operator_{op}",
elapsed_time=torch.tensor(elapsed_times[op], dtype=torch.float),
mem_stats=mem_stats,
rank=rank,
)
)

sorted_results = sorted(
results, key=lambda x: x.runtime_percentile(90), reverse=True
)

return sorted_results[:limit_results]


def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str:
if sharding_type == ShardingType.TABLE_WISE:
name = "tw-sharded"
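
For reference, a minimal standalone sketch (not from this PR) of calling the new benchmark_operators helper directly on a plain callable with is_pipeline=False; the Linear model, input shapes, and counts below are placeholder assumptions, and an available CUDA device is assumed. In the PR itself the helper is invoked with is_pipeline=True so that each profiled pass drives pipeline.progress(iter(bench_inputs)) instead.

import torch
from torchrec.distributed.benchmark.benchmark_utils import benchmark_operators

# Placeholder workload: with is_pipeline=False, any callable that accepts one
# element of bench_inputs at a time can be profiled.
model = torch.nn.Linear(128, 64).cuda()
bench_inputs = [torch.randn(32, 128, device="cuda") for _ in range(10)]

op_results = benchmark_operators(
    func_to_benchmark=model,
    bench_inputs=bench_inputs,
    num_benchmarks=3,       # profile the full input list 3 times
    device_type="cuda",
    target_operators=None,  # keep every profiled operator
    rank=0,
    limit_results=10,       # top 10 operators by p90 self device time
)
for result in op_results:
    print(result)
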