Change default gpu metric backend (pytorch#2501)
Summary:
The current GPU memory metric backends include dcgm and nvml. They report values from the hardware and should be accurate. This PR adds a native PyTorch way to collect GPU memory usage, using `torch.cuda.max_memory_allocated()`. The benefit is that it has lower overhead and stays accurate on a shared GPU server where other users run their own GPU processes, because we do not implement per-process filtering for the other two backends.

Use `--metrics-gpu-backend torch` to set the backend.
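
For reference, the torch backend boils down to the pattern below. This is a minimal sketch, not part of the diff: it assumes a CUDA device, and `run_one_iteration` is a hypothetical stand-in for the benchmark's work function.

```python
import torch

def torch_gpu_peak_mem_gb(run_one_iteration, num_iter: int = 10) -> float:
    # Clear the caching allocator's running peak counter and cached blocks so
    # the measurement reflects only the iterations below.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    for _ in range(num_iter):
        run_one_iteration()
    # Wait for outstanding kernels before reading the counter.
    torch.cuda.synchronize()
    # max_memory_allocated() only counts this process's torch allocations,
    # so other users' GPU processes do not affect the number.
    return torch.cuda.max_memory_allocated() / 10**9
```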

Pull Request resolved: pytorch#2501

Reviewed By: xuzhao9

Differential Revision: D64253410

Pulled By: FindHao

fbshipit-source-id: 09b0579846a6830e0e9735e8daeba4abd88bab17
FindHao authored and facebook-github-bot committed Oct 16, 2024
1 parent 21cc30d commit c396191
Showing 4 changed files with 76 additions and 36 deletions.
23 changes: 11 additions & 12 deletions run.py
@@ -477,18 +477,17 @@ def main() -> None:
)
parser.add_argument(
"--metrics-gpu-backend",
choices=["dcgm", "default"],
default="default",
choices=["torch", "nvml", "dcgm"],
default="torch",
help="""
Specify the backend [dcgm, default] to collect metrics.
In default mode, the latency(execution time) is collected by time.time_ns() and it is always enabled.
Optionally, - you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process().
- you can specify gpu peak memory usage by --metrics gpu_peak_mem, and it is collected by nvml library.
- you can specify flops by --metrics flops, and it is collected by fvcore.
In dcgm mode, the latency(execution time) is collected by time.time_ns() and it is always enabled.
Optionally,
- you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process().
- you can specify cpu and gpu peak memory usage by --metrics cpu_peak_mem,gpu_peak_mem, and they are collected by dcgm library.""",
Specify the backend [torch, nvml, dcgm] to collect metrics. In all modes,
the latency (execution time) is always collected using `time.time_ns()`. The CPU
and GPU peak memory usage metrics are optional. The CPU peak memory usage is
collected by `psutil.Process()` in all modes. In nvml mode, the GPU peak memory
usage is collected by the `nvml` library. In dcgm mode, the GPU peak memory usage is
collected by the `dcgm` library. In torch mode, the GPU peak memory usage is collected
by `torch.cuda.max_memory_allocated()`.
""",
)
args, extra_args = parser.parse_known_args()
if args.cudastreams and not args.device == "cuda":
@@ -541,7 +540,7 @@ def main() -> None:
)

check_dcgm()
elif "gpu_peak_mem" in metrics_needed:
elif metrics_gpu_backend == "nvml":
from torchbenchmark._components.model_analyzer.TorchBenchAnalyzer import (
check_nvml,
)
71 changes: 50 additions & 21 deletions torchbenchmark/util/experiment/metrics.py
@@ -4,11 +4,15 @@

import copy
import dataclasses
import os
import pathlib
import time
from typing import List, Optional, Tuple, Union

import psutil

import torch

from torchbenchmark import ModelTask
from torchbenchmark.util.experiment.instantiator import TorchBenchModelConfig
from torchbenchmark.util.model import BenchmarkModel
@@ -31,6 +35,11 @@ class TorchBenchModelMetrics:
model_flops: Optional[float]


def maybe_synchronize(device: str):
if device == "cuda":
torch.cuda.synchronize()


def get_latencies(
func, device: str, nwarmup=WARMUP_ROUNDS, num_iter=BENCHMARK_ITERS
) -> List[float]:
@@ -62,25 +71,30 @@ def get_peak_memory(
num_iter=MEMPROF_ITER,
export_metrics_file="",
metrics_needed=[],
metrics_gpu_backend="dcgm",
metrics_gpu_backend="torch",
cpu_monitored_pid=None,
) -> Tuple[Optional[float], Optional[str], Optional[float]]:
"Run one step of the model, and return the peak memory in MB."
from torchbenchmark._components.model_analyzer.TorchBenchAnalyzer import (
ModelAnalyzer,
)

new_metrics_needed = [
_ for _ in metrics_needed if _ in ["cpu_peak_mem", "gpu_peak_mem"]
]
if not new_metrics_needed:
raise ValueError(
f"Expected metrics_needed to be non-empty, get: {metrics_needed}"
)
mem_model_analyzer = ModelAnalyzer(
export_metrics_file, new_metrics_needed, metrics_gpu_backend, cpu_monitored_pid
)
continue_num_iter = BENCHMARK_ITERS - num_iter
if metrics_gpu_backend in ["dcgm", "nvml"]:
from torchbenchmark._components.model_analyzer.TorchBenchAnalyzer import (
ModelAnalyzer,
)

mem_model_analyzer = ModelAnalyzer(
export_metrics_file,
new_metrics_needed,
metrics_gpu_backend,
cpu_monitored_pid,
)
else:
mem_model_analyzer = None

def work_func():
if device == "cuda":
@@ -99,22 +113,37 @@ def work_func():
num_iter = BENCHMARK_ITERS
else:
num_iter = MEMPROF_ITER
mem_model_analyzer.start_monitor()

for _i in range(num_iter):
work_func()
mem_model_analyzer.stop_monitor()
mem_model_analyzer.aggregate()
device_id = None
gpu_peak_mem = None
cpu_peak_mem = None
if "gpu_peak_mem" in metrics_needed:
device_id, gpu_peak_mem = mem_model_analyzer.calculate_gpu_peak_mem()
if "cpu_peak_mem" in metrics_needed:
cpu_peak_mem = mem_model_analyzer.calculate_cpu_peak_mem()
if export_metrics_file:
mem_model_analyzer.update_export_name("_peak_memory")
mem_model_analyzer.export_all_records_to_csv()

if mem_model_analyzer:
mem_model_analyzer.start_monitor()
for _i in range(num_iter):
work_func()
mem_model_analyzer.stop_monitor()
mem_model_analyzer.aggregate()

if "gpu_peak_mem" in metrics_needed:
device_id, gpu_peak_mem = mem_model_analyzer.calculate_gpu_peak_mem()
if "cpu_peak_mem" in metrics_needed:
cpu_peak_mem = mem_model_analyzer.calculate_cpu_peak_mem()
if export_metrics_file:
mem_model_analyzer.update_export_name("_peak_memory")
mem_model_analyzer.export_all_records_to_csv()
else:
if device == "cuda":
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
for _ in range(num_iter):
work_func()
if device == "cuda":
device_id = torch.cuda.current_device()
gpu_peak_mem = torch.cuda.max_memory_allocated() / 10**9
total = psutil.virtual_memory().total
percentage = psutil.Process(os.getpid()).memory_percent()
cpu_peak_mem = percentage * total / 10**9
return cpu_peak_mem, device_id, gpu_peak_mem


6 changes: 3 additions & 3 deletions torchbenchmark/util/triton_op.py
@@ -892,7 +892,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
or "gpu_peak_mem" in self.required_metrics
):
metrics.cpu_peak_mem, _device_id, metrics.gpu_peak_mem = (
self.get_peak_mem(fn)
self.get_peak_mem(fn, self.tb_args.metrics_gpu_backend)
)
if not baseline and "accuracy" in self.required_metrics:
metrics.accuracy = (
@@ -1014,13 +1014,13 @@ def _init_extra_metrics() -> Dict[str, Any]:
return metrics

def get_peak_mem(
self, fn: Callable
self, fn: Callable, metrics_memory_usage_backend: str
) -> Tuple[Optional[float], Optional[str], Optional[float]]:
return get_peak_memory(
func=fn,
device=self.device,
metrics_needed=["gpu_peak_mem", "cpu_peak_mem"],
metrics_gpu_backend="nvml",
metrics_gpu_backend=metrics_memory_usage_backend,
)

def nsys_rep(self, input_id: int, fn_name: str) -> str:
12 changes: 12 additions & 0 deletions userbenchmark/triton/run.py
@@ -93,6 +93,18 @@ def get_parser(args=None):
default=None,
help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
)
parser.add_argument(
"--metrics-gpu-backend",
choices=["torch", "nvml"],
default="torch",
help=(
"Specify the backend [torch, nvml] to collect metrics. In all modes, the latency "
"(execution time) is always collected using `time.time_ns()`. The CPU peak memory "
"usage is collected by `psutil.Process()`. In nvml mode, the GPU peak memory usage "
"is collected by the `nvml` library. In torch mode, the GPU peak memory usage is "
"collected by `torch.cuda.max_memory_allocated()`."
),
)
parser.add_argument(
"--only",
default=None,
