55 changes: 55 additions & 0 deletions superbench/benchmarks/model_benchmarks/pytorch_base.py
@@ -343,3 +343,58 @@ def _timer(self):
if self._gpu_available:
torch.cuda.synchronize()
return time.time()

def _benchmark(self):
"""Wrap super._benchmark with profiler context if enabled by environment variable.

Set SB_ENABLE_PYTORCH_PROFILER='1' to enable profiling.
"""
        # Check if this is an NVIDIA GPU
if not (torch.cuda.is_available() and torch.version.cuda is not None):
return super()._benchmark()

# Check if profiling is enabled via environment variable
enable_profiler = os.environ.get('SB_ENABLE_PYTORCH_PROFILER', '0') == '1'

if not enable_profiler:
# Run without profiling
return super()._benchmark()

# Run with profiling enabled
logger.info('PyTorch profiler enabled for model: {}'.format(self._name))
ret = None

from torch.profiler import profile, ProfilerActivity
from torch.autograd import DeviceType
import json

if self._local_rank is None:
local_rank = 0
else:
local_rank = self._local_rank

diag_agent_prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True)
dump_file_dir = os.environ.get('SB_TORCH_PROFILER_TRACE_DIR', '.')
diag_agent_dump_file_path = f'{dump_file_dir}/torch-profiler-sb-{self._name}-{local_rank}.json'
diag_agent_prof.__enter__()

ret = super()._benchmark()

diag_agent_prof.__exit__(None, None, None)
diag_agent_events = []
for event in diag_agent_prof.events():
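            # Keep only CPU-side operator events; their cuda_time field carries the device time attributed to each op.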
if event.device_type != DeviceType.CPU:
continue
diag_agent_event = {
'name': event.name,
'input_shapes': event.input_shapes,
'input_values': event.concrete_inputs,
}
diag_agent_event['cpu_time'] = event.cpu_time
diag_agent_event['gpu_time'] = event.cuda_time
diag_agent_event['start_time'] = event.time_range.start
diag_agent_events.append(diag_agent_event)
with open(diag_agent_dump_file_path, 'w') as f:
json.dump(diag_agent_events, f, sort_keys=True)

return ret
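
For reference, each JSON file written above is a flat list of operator records with 'name', 'input_shapes', 'input_values', 'cpu_time', 'gpu_time', and 'start_time' keys. A minimal sketch (not part of this PR) of post-processing one of these traces; the file name is illustrative, and torch.profiler reports the times in microseconds:

import json
from collections import defaultdict

# Load one per-rank trace dumped by the code above (example path).
with open('torch-profiler-sb-pytorch-models-0.json') as f:
    events = json.load(f)

# Accumulate CPU and GPU time per operator name.
totals = defaultdict(lambda: {'count': 0, 'cpu_time': 0.0, 'gpu_time': 0.0})
for event in events:
    op = totals[event['name']]
    op['count'] += 1
    op['cpu_time'] += event['cpu_time']
    op['gpu_time'] += event['gpu_time']

# Show the ten operators with the largest accumulated GPU time.
for name, stats in sorted(totals.items(), key=lambda kv: kv[1]['gpu_time'], reverse=True)[:10]:
    print(f"{name}: count={stats['count']} cpu={stats['cpu_time']:.0f}us gpu={stats['gpu_time']:.0f}us")
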
40 changes: 34 additions & 6 deletions superbench/runner/runner.py
@@ -131,33 +131,61 @@ def __get_mode_command(self, benchmark_name, mode, timeout=None):
if timeout is not None:
exec_command = 'timeout {timeout} {command}'.format(timeout=timeout, command=exec_command)

# Enable nsys profiling based on environment variable
enable_nsys = os.environ.get('SB_ENABLE_NSYS', '') == '1'
trace_dir = os.environ.get('SB_NSYS_TRACE_DIR', self._sb_output_dir)

mode_command = exec_command
if mode.name == 'local':
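            # In local mode, wrap only the rank-0 process with nsys when profiling is enabled.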
trace_command = (
f'nsys profile --output {trace_dir}/{benchmark_name}_{mode.proc_rank}_traces '
f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
) if enable_nsys and mode.proc_rank == 0 else ''
            # Build the command parts, including the prefix and trace command only when they are non-empty
command_parts = []
prefix = mode.prefix.format(proc_rank=mode.proc_rank, proc_num=mode.proc_num)
if prefix:
command_parts.append(prefix)
if trace_command:
command_parts.append(trace_command)
command_parts.append(exec_command)
mode_command = ' '.join(command_parts)
mode_command = f'PROC_RANK={mode.proc_rank} {mode_command}'
elif mode.name == 'torch.distributed':
# TODO: replace with torch.distributed.run in v1.9
# TODO: only supports node_num=1 and node_num=all currently
torch_dist_params = (
'' if 'node_num' in mode and mode.node_num == 1 else
'--nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
)

nsys_prefix = (
f'nsys profile --output {trace_dir}/{benchmark_name}_traces '
f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
) if enable_nsys else ''

mode_command = (
f'{nsys_prefix}'
f'torchrun'
f' --no_python --nproc_per_node={mode.proc_num} {torch_dist_params}{exec_command}'
f' superbench.benchmarks.{benchmark_name}.parameters.distributed_impl=ddp'
f' superbench.benchmarks.{benchmark_name}.parameters.distributed_backend=nccl'
)
elif mode.name == 'mpi':
trace_command = (
f'nsys profile --output {trace_dir}/{benchmark_name}_{mode.proc_rank}_traces '
f'--backtrace none --sample none --force-overwrite true --cpuctxsw none --trace cuda,nvtx '
) if enable_nsys else ''
mode_command = (
'{trace} '
'mpirun ' # use default OpenMPI in image
'-tag-output ' # tag mpi output with [jobid,rank]<stdout/stderr> prefix
'-allow-run-as-root ' # allow mpirun to run when executed by root user
'{host_list} ' # use prepared hostfile or specify nodes and launch {proc_num} processes on each node
'-bind-to numa ' # bind processes to numa
'{mca_list} {env_list} {command}'
).format(
trace=trace_command,
host_list=f'-host localhost:{mode.proc_num}' if 'node_num' in mode and mode.node_num == 1 else
f'-hostfile hostfile -map-by ppr:{mode.proc_num}:node' if 'host_list' not in mode else '-host ' +
','.join(f'{host}:{mode.proc_num}' for host in mode.host_list),
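
Both changes are opt-in through environment variables, so profiling can be enabled without modifying the benchmark configuration. A minimal sketch of a wrapper that sets the new variables before invoking the runner, assuming the environment propagates to the benchmark processes; the 'sb run' arguments and trace directory are illustrative, not taken from this PR:

import os
import subprocess

env = dict(os.environ)
env.update({
    'SB_ENABLE_PYTORCH_PROFILER': '1',                 # dump torch.profiler operator events per model benchmark
    'SB_TORCH_PROFILER_TRACE_DIR': '/tmp/sb-traces',   # where the per-rank JSON files are written
    'SB_ENABLE_NSYS': '1',                              # wrap launch commands with 'nsys profile'
    'SB_NSYS_TRACE_DIR': '/tmp/sb-traces',              # where nsys writes its report files
})

# Illustrative invocation; host file and config file names are placeholders.
subprocess.run(['sb', 'run', '-f', 'local.ini', '-c', 'config.yaml'], env=env, check=True)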