Provide Memory Benchmarking Feature to Benchmarking Code (#14)
* add gpu memory logging support

* made improvements to GPU reference and result collation

* Renamed memory logging argument to reflect its readings as reserved memory using nvidia-smi and changed aggregation function in result collation

* variable renames

* manual linting

* added memory logging functionality via HFTrainer

* added support to benchmark memory using HFTrainer and updated README with explanation of the 2 memory benchmarking options

* addressed changes requested in PR #14

* fix bug and simplify gpu logs aggregation logic

* fixes to calculation of HFTrainer Mem Logging values

* fix calculations

* more fixes

* fix to ignore including stage inside max calculation of alloc memory

* more comments and README updates

* added fix to keyerror due to empty output dict from OOM

* manual linting

* added benchmark results to refs

* remove unnecessary columns in results gathering

* made changes to results gathering
achew010 authored May 27, 2024
1 parent 2003a3e commit f1895b7
Showing 6 changed files with 398 additions and 60 deletions.
92 changes: 92 additions & 0 deletions scripts/benchmarks/README.md
@@ -92,3 +92,95 @@ python benchmark.py --help

Note:
- in `run_benchmarks.sh` we will clear the `RESULT_DIR` if it exists, to avoid contamination with old results. To protect against overwrites, always run with `NO_OVERWRITE=true`.

## Logging GPU Memory

There are two ways to benchmark memory in `run_benchmarks.sh`:
- Setting the environment variable `MEMORY_LOGGING=nvidia` will use Nvidia's `nvidia-smi` CLI
- Setting the environment variable `MEMORY_LOGGING=huggingface` (default) will use the HuggingFace `HFTrainer` API

Both approaches will print the memory values to the benchmark report.
- For Nvidia, the result column will be `nvidia_mem_reserved`
- For Torch/HF, the result columns will be `peak_torch_mem_alloc_in_bytes` and `torch_mem_alloc_in_bytes`

### Nvidia-SMI `nvidia-smi`
`nvidia-smi` is a command-line utility (CLI) based on the NVIDIA Management Library (NVML). A separate process call is used to start, log and finally terminate the CLI for every experiment.

The keyword `memory.used` is passed to the `--query-gpu` argument to log the memory usage at a fixed interval. The full list of keywords that can be logged can be found via `nvidia-smi --help-query-gpu`.
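
For illustration, the logger can be started as a side process along the lines of the following sketch (this mirrors what `benchmark.py` does; the helper name and file path are illustrative):

```
# Illustrative sketch: start nvidia-smi as a logging side process.
import subprocess

GPU_TABLE = "timestamp,name,index,memory.used"

def start_gpu_logger(log_path: str, device_ids: str = "0", interval_secs: int = 1):
    # writes one CSV row per device per interval until terminate() is called
    return subprocess.Popen(
        [
            "nvidia-smi",
            "--query-gpu", GPU_TABLE,
            "--format", "csv",
            "--id", device_ids,
            "--loop", str(interval_secs),
        ],
        stdout=open(log_path, "w"),
        text=True,
    )

# logger = start_gpu_logger("gpu_memory_logs.csv", device_ids="0,1")
# ... run the experiment ...
# logger.terminate()
```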

Since it runs in a separate process, it is less likely to affect the training. However, it is a coarser approach than the HF method, as NVML's definition of used memory is the sum of allocated and reserved device memory. Refer to the NVML [documentation](https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html#structnvmlMemory__t:~:text=Sum%20of%20Reserved%20and%20Allocated%20device%20memory%20(in%20bytes).%20Note%20that%20the%20driver/GPU%20always%20sets%20aside%20a%20small%20amount%20of%20memory%20for%20bookkeeping).

After every experiment (see the aggregation sketch below):
- the logged values are calibrated to remove any existing foreign memory values,
- the peak value for each GPU device is taken,
- and the values are finally averaged across all devices.
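
A minimal pandas sketch of this aggregation (assuming a CSV produced by the `nvidia-smi` logger; the column name follows the `nvidia-smi` CSV output and the helper name is illustrative):

```
import pandas as pd

def summarize_gpu_mem(log_path: str) -> float:
    logs = pd.read_csv(log_path, skipinitialspace=True)
    # strip the " MiB" suffix and convert the readings to float
    used = logs["memory.used [MiB]"].str.replace(" MiB", "").astype(float)
    by_device = used.groupby(logs["index"])
    # calibrate by subtracting each device's first reading, then take the peak
    peak_per_device = by_device.max() - by_device.first()
    # the reported value is the average across devices
    return peak_per_device.mean()
```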

### Torch/HuggingFace `HFTrainer`
`HFTrainer` has a feature to log memory through the `skip_memory_metrics=False` training argument. The [documentation](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments.skip_memory_metrics) mentions that setting this argument to `False` will affect training speed. In our tests so far (below), we do not see a significant difference in throughput (tokens/sec) when using this argument.
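
A minimal sketch of enabling this (the model and dataset setup are omitted and assumed to be defined elsewhere):

```
from transformers import Trainer, TrainingArguments

# `model` and `train_dataset` are placeholders assumed to be defined elsewhere
training_args = TrainingArguments(
    output_dir="results",
    skip_memory_metrics=False,  # turn on the *_mem_* keys in the output metrics
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
result = trainer.train()
# result.metrics now contains keys such as `train_mem_gpu_alloc_delta`
# and `train_mem_gpu_peaked_delta` (see the example output further below)
print(result.metrics)
```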

The `HFTrainer` API is more granular than `nvidia-smi`, as it uses `torch.cuda` to pinpoint memory usage inside the trainer:
- It reports the allocated memory by calling `torch.cuda.memory_allocated()` and `torch.cuda.max_memory_allocated()` inside its probes
- It places memory logging probes at different stages of the Trainer - `init`, `train`, `evaluate`, `predict`

#### NOTE:
- When in distributed mode, the Trainer will only log the rank 0 memory.
- For stability purposes, it only tracks the outer level of the train, evaluate and predict methods, i.e. if eval is called during train, there won't be a nested invocation of the memory probe.
- Any GPU memory incurred outside of the defined Trainer stages won't be tracked.

### Additional Details

#### Calculating Memory from HFTrainer Output Metrics

This is an example of the memory values that `HFTrainer` will produce in the outputs of `train()`:
```
output_metrics = {
'train_runtime': 191.2491,
'train_samples_per_second': 0.209,
'train_steps_per_second': 0.052,
'train_tokens_per_second': 428.342,
'train_loss': 1.0627506256103516,
'init_mem_cpu_alloc_delta': 4096,
'init_mem_gpu_alloc_delta': 0,
'init_mem_cpu_peaked_delta': 0,
'init_mem_gpu_peaked_delta': 0,
'train_mem_cpu_alloc_delta': 839086080,
'train_mem_gpu_alloc_delta': -17491768832,
'train_mem_cpu_peaked_delta': 0,
'train_mem_gpu_peaked_delta': 26747825664,
'before_init_mem_cpu': 5513297920,
'before_init_mem_gpu': 36141687296,
'epoch': 0.01
}
```
We refer to the keys of the memory metrics in this order:
- `before_init_mem_X` as stage0
- `init_mem_X` as stage1
- `train_mem_X` as stage2
- ...

We currently compute the memory values in the report by taking the largest of the running sums. For example, for the allocated memory value:
```
max([
stage0_mem + stage1_allocated_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta,
...
])
```
For the peak memory value:
```
max([
stage0_mem + stage1_allocated_delta + stage1_peaked_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta,
...
])
```
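
For illustration, plugging the sample `output_metrics` above into these formulas gives (arithmetic sketch only; the exact set of stages entering the `max` is governed by `extract_gpu_memory_metrics` in `benchmark.py`):

```
stage0_mem = 36141687296            # before_init_mem_gpu
stage1_alloc, stage1_peak = 0, 0    # init_mem_gpu_{alloc,peaked}_delta
stage2_alloc, stage2_peak = -17491768832, 26747825664  # train_mem_gpu_{alloc,peaked}_delta

torch_mem_alloc_in_bytes = max(
    stage0_mem + stage1_alloc,
    stage0_mem + stage1_alloc + stage2_alloc,
)  # = 36141687296 bytes (~33.7 GiB)

peak_torch_mem_alloc_in_bytes = max(
    stage0_mem + stage1_alloc + stage1_peak,
    stage0_mem + stage1_alloc + stage2_alloc + stage2_peak,
)  # = 45397744128 bytes (~42.3 GiB)
```
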
Notice that we do not include `stage0_mem` alone when computing the max value. This is to avoid misleading comparisons between GPTQ-LoRA and others. GPTQ-LoRA + FSDP currently does not support low-memory mode, as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18). The `stage0_mem` value of GPTQ-LoRA + FSDP will be larger than expected, as the model is loaded fully before the trainer is initialized and only subsequently sharded internally in `trainer.prepare`. This can lead to misleading comparisons against other variants that are loaded in low-memory mode and thus have a smaller `stage0_mem` consumption than GPTQ-LoRA + FSDP. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back inside the max computation.

We compare memory values between Nvidia-SMI and Torch in this PR - [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14).
205 changes: 203 additions & 2 deletions scripts/benchmarks/benchmark.py
@@ -13,6 +13,7 @@
from transformers import AutoConfig, HfArgumentParser, TrainingArguments
import datasets
import pandas as pd
import torch
import yaml

"""
@@ -72,6 +73,68 @@
"torch.distributed.elastic.multiprocessing.errors.ChildFailedError"
]

FILE_MEM = "gpu_memory_logs.csv"
GPU_LOG_USED_MEM_COLUMN_NAME = "memory.used [MiB]"
GPU_LOG_METRIC_SUFFIX = " MiB"
GPU_TABLE = "timestamp,name,index,memory.used"
RESULT_FIELD_RESERVED_GPU_MEM = "nvidia_mem_reserved"
RESULT_FIELD_DEVICE_NAME = "gpu_device_name"

HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT = "before_init_mem_gpu"
HF_TRAINER_LOG_GPU_STAGE_INIT = "init_mem_gpu"
HF_TRAINER_LOG_GPU_STAGE_TRAIN = "train_mem_gpu"
KEYWORD_PEAKED_DELTA = "peaked_delta"
KEYWORD_ALLOC_DELTA = "alloc_delta"
HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics"
RESULT_FIELD_ALLOCATED_GPU_MEM = "torch_mem_alloc_in_bytes"
RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "peak_torch_mem_alloc_in_bytes"


def extract_gpu_memory_metrics(output_metrics) -> Tuple[float, float]:
"""
This function computes the gpu summary metrics from the output metrics of Trainer
when `skip_memory_metrics` is set to `False` in transformers.TrainingArguments.
It is called only when `--skip_memory_metrics` exists in the experiment args
and is set to False. The memory key values are expected to be inside output_metrics;
if output_metrics is empty, peak=0 and usage=0 are returned.
Returns
- gpu_peak value in Bytes
- gpu_usage value in Bytes
"""
# Assumes train stage is always called
# this is a tuple of stage names, and a bool to say if it should be included in the summarized number
# we exclude the model loading stages for now, due to
# https://github.com/foundation-model-stack/fms-acceleration/issues/18
# we will re-enable the loading stages later on once this issue is addressed
if len(output_metrics.keys()) < 1:
return 0, 0

trainer_stage_order = [
(HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False),
(HF_TRAINER_LOG_GPU_STAGE_INIT, False),
(HF_TRAINER_LOG_GPU_STAGE_TRAIN, True),
]
alloc_running_sum = 0
list_of_alloc_running_sums = []
list_of_peak_running_sums = []
for STAGE_NAME, include in trainer_stage_order:
delta_key = f"{STAGE_NAME}_{KEYWORD_ALLOC_DELTA}"
alloc_running_sum += (
output_metrics[delta_key]
if delta_key in output_metrics
else output_metrics[STAGE_NAME]
)
peak_delta = output_metrics.get(f"{STAGE_NAME}_{KEYWORD_PEAKED_DELTA}", 0)
if include:
list_of_alloc_running_sums.append(alloc_running_sum)
list_of_peak_running_sums.append(alloc_running_sum + peak_delta)

max_alloc_running_sum = max(list_of_alloc_running_sums)
max_peak_running_sum = max(list_of_peak_running_sums)
return max_peak_running_sum, max_alloc_running_sum


def get_hf_arguments_with_no_value(dataclass_types):
"""this function will return a map (str, bool) of true/false arguments.
@@ -292,8 +355,15 @@ def __init__(
self.stderr_filename = os.path.join(self.save_dir, FILE_STDERR)
self.command_filename = os.path.join(self.save_dir, FILE_SHELL_COMMAND)
self.results_filename = os.path.join(self.save_dir, FILE_RESULTS)
self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM)

def run(self, run_cmd: str, environment_variables: Dict = None):
def run(
self,
run_cmd: str,
environment_variables: Dict = None,
log_nvidia_smi: bool = False,
memory_log_interval_secs: int = 1,
):

# form the command line
commands = []
@@ -308,6 +378,39 @@ def run(self, run_cmd: str, environment_variables: Dict = None):
self.environment = environment_variables
self.experiment_args_str = commands
os.makedirs(self.save_dir, exist_ok=True)

if log_nvidia_smi:
"""
Opens a parallel process to log the device memory of the main experiment process.
- Logs memory at intervals to a csv file in `self.save_dir`
- Terminates at the end of experiment
- GPU log is read and aggregated when the experiment ends & results are saved in Experiment.write_result.
NOTE: This feature assumes the following
1. Experiment is the only process on the gpu devices -
there are no other processes running on the device in parallel.
More info can be logged from nvidia-smi by expanding the GPU_TABLE argument,
e.g. "timestamp,name,index,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used"
See `nvidia-smi --help-query-gpu` for more reference
"""
nvidia_logging_cmd = [
"nvidia-smi",
"--query-gpu",
GPU_TABLE,
"--format",
"csv",
"--id",
str(environment_variables["CUDA_VISIBLE_DEVICES"]),
"--loop",
str(memory_log_interval_secs),
]
memory_process = subprocess.Popen(
nvidia_logging_cmd,
stdout=open(self.gpu_log_filename, "w"),
text=True,
)

subprocess.run(
self.shell_command,
capture_output=False,
@@ -317,6 +420,9 @@
env={**os.environ.copy(), **environment_variables},
)

if log_nvidia_smi:
memory_process.terminate()

def get_experiment_final_metrics(
self, final_metrics_keys: List[str] = ["train_loss", "train_runtime"]
):
@@ -374,13 +480,69 @@ def maybe_get_experiment_error_traceback(self):

return None if len(results) == 0 else results

def get_peak_mem_usage_by_device_id(self):
"""
This function retrieves the raw measurements of reserved GPU memory per device across the experiment -
computing the peak value for each gpu and then performing a simple calibration (subtracting the first reading from the peak values).
Returns:
- pd.Series of peak memory usage per device id
- the device name as string - e.g. "NVIDIA A100-SXM4-80GB"
Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series:
- pd.Series
index
0 52729.0
1 52783.0
Name: memory.used [MiB], dtype: float64
"""

# group the gpu readings into device ids
gpu_logs = pd.read_csv(self.gpu_log_filename, skipinitialspace=True)
# assume that all the devices have the same device name
device_name = gpu_logs.name.iloc[-1]
# extract and convert the gpu memory usage as float values
gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[
GPU_LOG_USED_MEM_COLUMN_NAME
].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, "")))
mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME]
# Calibrate values by subtracting out the initial values of the GPU readings
# to ensure no existing memory is counted in addition with the experiment
initial_values = mem_usage_by_device_id.first()
peak_values = mem_usage_by_device_id.max()
return peak_values.sub(initial_values), device_name

def write_result(self):
"Function to write a json result file"

# save some basic args
save_result = ConfigUtils.convert_args_to_dict(self.experiment_args_str)
save_result["num_gpus"] = self.num_gpus

# if a gpu log file exist, process the raw nvidia logs and write to result
if os.path.isfile(self.gpu_log_filename):
# Add GPU info and measurements into the result saving
peak_mem_usage_by_device_id, device_name = (
self.get_peak_mem_usage_by_device_id()
)
save_result[RESULT_FIELD_DEVICE_NAME] = device_name
# Memory usage is averaged across all devices in the final result
save_result[RESULT_FIELD_RESERVED_GPU_MEM] = (
peak_mem_usage_by_device_id.mean()
)

# process gpu mem from output metrics and write to result
# check if HF_ARG_SKIP_MEMORY_METRIC is set to False in experiment arg
# this arg is specified explicitly inside `def generate_list_of_experiments`
argument_idx = self.experiment_arg.index(HF_ARG_SKIP_MEMORY_METRIC)
write_memory_metric = not self.experiment_arg[argument_idx + 1]
if write_memory_metric:
peak_gpu_mem, gpu_allocated_mem = extract_gpu_memory_metrics(
self.get_experiment_final_metrics()
)
save_result[RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM] = peak_gpu_mem
save_result[RESULT_FIELD_ALLOCATED_GPU_MEM] = gpu_allocated_mem

# if there is an error we save the error message else we save the final result
maybe_error_messages = self.maybe_get_experiment_error_traceback()
if maybe_error_messages is None:
@@ -493,6 +655,7 @@ def generate_list_of_experiments(
output_dir: str = "results",
hf_products_dir: str = "hf",
dry_run: bool = False,
log_memory_in_trainer: bool = False,
) -> List[Experiment]:
"""Construct list of experiments to be run. Takes in default_config and
any matrices in scenario and experiment_config
@@ -504,6 +667,8 @@
expr_arg_w_outputdir = exp_arg + [
"--output_dir",
os.path.join(experiment_output_dir, hf_products_dir),
HF_ARG_SKIP_MEMORY_METRIC,
not log_memory_in_trainer,
]
expr_cls = Experiment if not dry_run else DryRunExperiment
_expr = expr_cls(
@@ -578,6 +743,19 @@ def compress(df):

def main(args):

# Gathers available gpu device ids that will be used for benchmarking.
# If "CUDA_VISIBLE_DEVICES" is specified, it will return the specified device ids
# if no gpu ids are specified, it will default to the enumeration of available ids
assert torch.cuda.device_count() > 0, "No device detected for memory logging!"
available_gpus_indices = os.environ.get("CUDA_VISIBLE_DEVICES")
if available_gpus_indices:
available_gpus_indices = available_gpus_indices.split(",")
else:
available_gpus_indices = [str(i) for i in range(torch.cuda.device_count())]

if args.dry_run and args.log_nvidia_smi:
args.log_nvidia_smi = False

# 1. Prepares a standard BenchmarkDataset
# TODO: consider caching the json file
if not args.no_data_processing:
@@ -600,6 +778,7 @@
experiment_args,
output_dir=args.results_output_path,
dry_run=args.dry_run,
log_memory_in_trainer=args.log_memory_hf,
)
):
if experiment.num_gpus > 1:
@@ -611,10 +790,20 @@
else:
prefix = COMMAND_PYTHON

device_ids = ",".join([str(i) for i in range(experiment.num_gpus)])
assert experiment.num_gpus <= len(
available_gpus_indices
), "Experiment requires more gpus than is available on the platform."
"""
Experiment will take only the ids from the available gpu indices,
this ensures that whatever GPUs are exposed to benchmark.py are the only
devices that each experiment can have access to.
"""
device_ids = ",".join(available_gpus_indices[: experiment.num_gpus])

experiment.run(
f"{prefix} {FMS_TRAINER}",
environment_variables={"CUDA_VISIBLE_DEVICES": device_ids},
log_nvidia_smi=args.log_nvidia_smi,
)

# write results and store pointers to files
@@ -746,5 +935,17 @@ def main(args):
help="ensures 'model_name_or_paths 'specified in scenarios.yaml work. "
"Useful to check model paths specified correctly before lengthly benchmark runs.",
)
parser.add_argument(
"--log_nvidia_smi",
action="store_true",
help="Use `nvidia-smi` API to log reserved memory of benchmarks",
)

parser.add_argument(
"--log_memory_hf",
action="store_true",
help="Uses memory logging from HF Trainer Arguments API to log gpu memory, for distributed runs only rank 0 is measured",
)

args = parser.parse_args()
main(args)
