add gpu memory logging support
achew010 committed May 17, 2024
1 parent 697bbca commit 2e092a1
Showing 2 changed files with 71 additions and 3 deletions.
72 changes: 70 additions & 2 deletions scripts/benchmarks/benchmark.py
@@ -14,6 +14,7 @@
import datasets
import pandas as pd
import yaml
import torch

"""
This benchmarking script
@@ -72,6 +73,11 @@
"torch.distributed.elastic.multiprocessing.errors.ChildFailedError"
]

FILE_MEM = "gpu_memory_logs.csv"
GPU_LOG_USED_MEM_COLUMN_NAME = "memory.used [MiB]"
GPU_LOG_METRIC_SUFFIX = " MiB"
GPU_TABLE = "timestamp,name,index,memory.used"
REPORT_GPU_FIELD_NAME = "gpu_mem"
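# for reference, the csv produced by this nvidia-smi query starts with a header like
# "timestamp, name, index, memory.used [MiB]" (hence GPU_LOG_USED_MEM_COLUMN_NAME), and each
# memory reading carries the " MiB" suffix that is later stripped via GPU_LOG_METRIC_SUFFIX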

def get_hf_arguments_with_no_value(dataclass_types):
"""this function will return a map (str, bool) of true/false arguments.
@@ -292,8 +298,10 @@ def __init__(
self.stderr_filename = os.path.join(self.save_dir, FILE_STDERR)
self.command_filename = os.path.join(self.save_dir, FILE_SHELL_COMMAND)
self.results_filename = os.path.join(self.save_dir, FILE_RESULTS)
self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM)

def run(self, run_cmd: str, environment_variables: Dict = None):

def run(self, run_cmd: str, environment_variables: Dict = None, log_memory: bool = False, memory_log_interval_secs: int = 1):

# form the command line
commands = []
@@ -308,6 +316,38 @@ def run(self, run_cmd: str, environment_variables: Dict = None):
self.environment = environment_variables
self.experiment_args_str = commands
os.makedirs(self.save_dir, exist_ok=True)

if log_memory:
'''
Opens a parallel process to log the device memory of the main experiment process.
- Logs memory at fixed intervals to a csv file in `self.save_dir`
- Terminates at the end of the experiment
- The GPU log is read and aggregated when the experiment ends, and the results are saved in Experiment.write_result
NOTE: assumes the experiment is the only workload on the gpu devices,
i.e. no other processes are running on them in parallel.
More details can be logged from nvidia-smi by expanding the GPU_TABLE argument,
e.g. "timestamp,name,index,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used"
'''
assert torch.cuda.device_count() > 0, "No device detected for memory logging!"
nvidia_logging_cmd = [
"nvidia-smi",
"--query-gpu",
GPU_TABLE,
"--format",
"csv",
"--id",
str(environment_variables['CUDA_VISIBLE_DEVICES']),
"--loop",
str(memory_log_interval_secs),
]
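# e.g. with CUDA_VISIBLE_DEVICES="0,1" and the default 1-second interval, this list builds the
# command: nvidia-smi --query-gpu timestamp,name,index,memory.used --format csv --id 0,1 --loop 1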
memory_process = subprocess.Popen(
nvidia_logging_cmd,
stdout=open(self.gpu_log_filename, "w"),
text=True,
)

subprocess.run(
self.shell_command,
capture_output=False,
Expand All @@ -317,6 +357,8 @@ def run(self, run_cmd: str, environment_variables: Dict = None):
env={**os.environ.copy(), **environment_variables},
)

if log_memory: memory_process.terminate()

def get_experiment_final_metrics(
self, final_metrics_keys: List[str] = ["train_loss", "train_runtime"]
):
@@ -374,12 +416,29 @@ def maybe_get_experiment_error_traceback(self):

return None if len(results) == 0 else results

def get_avg_mem_usage_per_sec_by_device_id(self, min_measurement_in_mib=0):
'''
Retrieves the gpu memory logs and returns the average memory consumed per device across the experiment.
Returns a pd.Series of the average memory usage in MiB, indexed by device id.
'''
grouped_indices = pd.read_csv(self.gpu_log_filename, skipinitialspace=True).groupby(by='index')
# convert the per-second readings of each device (e.g. "1234 MiB") to float MiB values
mem_usage_by_device_id = grouped_indices.apply(
lambda x: x[GPU_LOG_USED_MEM_COLUMN_NAME].str.replace(GPU_LOG_METRIC_SUFFIX, '').astype(float)
)
# keep only measurements above a minimum MiB value before taking the average usage per device
mem_usage_by_device_id = mem_usage_by_device_id[mem_usage_by_device_id > min_measurement_in_mib]
# reduce to a series with the GPU id as index and the corresponding average memory usage as values
mem_usage_by_device_id = mem_usage_by_device_id.groupby(by='index').mean().squeeze()
return mem_usage_by_device_id

def write_result(self):
"Function to write a json result file"

# save some basic args
save_result = ConfigUtils.convert_args_to_dict(self.experiment_args_str)
save_result["num_gpus"] = self.num_gpus
save_result['num_gpus'] = self.num_gpus
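# mean of the per-device average memory usage (MiB), reported under the "gpu_mem" field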
save_result[REPORT_GPU_FIELD_NAME] = self.get_avg_mem_usage_per_sec_by_device_id().mean()

# if there is an error we save the error message else we save the final result
maybe_error_messages = self.maybe_get_experiment_error_traceback()
@@ -578,6 +637,9 @@ def compress(df):

def main(args):

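# dry runs only validate model paths, so gpu memory logging is disabled for them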
if args.dry_run and args.log_memory:
setattr(args, "log_memory", False)

# 1. Prepares a standard BenchmarkDataset
# TODO: consider caching the json file
if not args.no_data_processing:
@@ -615,6 +677,7 @@ def main(args):
experiment.run(
f"{prefix} {FMS_TRAINER}",
environment_variables={"CUDA_VISIBLE_DEVICES": device_ids},
log_memory=args.log_memory,
)

# write results and store pointers to files
@@ -746,5 +809,10 @@ def main(args):
help="ensures 'model_name_or_paths 'specified in scenarios.yaml work. "
"Useful to check model paths specified correctly before lengthly benchmark runs.",
)
parser.add_argument(
"--log_memory", action='store_true',
help="logs the gpu memory usage of each experiment to a csv file in its results directory "
"and reports the average memory consumed per device.",
)
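# with the flag enabled, e.g. `python scripts/benchmarks/benchmark.py --log_memory ...`,
# each experiment additionally writes a gpu_memory_logs.csv alongside its other result files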
args = parser.parse_args()
main(args)
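
A minimal sketch of how a per-experiment gpu_memory_logs.csv written by this commit could be inspected on its own, assuming the columns queried via GPU_TABLE above (the file path is illustrative):

import pandas as pd

# read the log written by the nvidia-smi subprocess (one row per device per interval)
df = pd.read_csv("gpu_memory_logs.csv", skipinitialspace=True)
# strip the " MiB" suffix so the used-memory column becomes numeric
df["memory.used [MiB]"] = df["memory.used [MiB]"].str.replace(" MiB", "").astype(float)
# average used memory per device id over the whole run
avg_per_device = df.groupby("index")["memory.used [MiB]"].mean()
print(avg_per_device)         # one value per GPU id, in MiB
print(avg_per_device.mean())  # what write_result reports under "gpu_mem"
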
2 changes: 1 addition & 1 deletion scripts/benchmarks/scenarios.yaml
@@ -64,4 +64,4 @@ scenarios:
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- TheBloke/Nous-Hermes-Llama2-70B-GPTQ
- 'TheBloke/Llama-2-70B-GPTQ'
