diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md
index c4be3268..115719b7 100644
--- a/scripts/benchmarks/README.md
+++ b/scripts/benchmarks/README.md
@@ -92,3 +92,95 @@ python benchmark.py --help
 Note:
 - in `run_benchmarks.sh` we will clear the `RESULT_DIR` if it exists, to avoid contamination with old results. To protect against overwriting results, always run with `NO_OVERWRITE=true`.
+
+## Logging GPU Memory
+
+There are two ways to log GPU memory in `run_benchmarks.sh`:
+- Setting the environment variable `MEMORY_LOGGING=nvidia` will use the Nvidia `nvidia-smi` CLI
+- Setting the environment variable `MEMORY_LOGGING=huggingface` (default) will use the HuggingFace `HFTrainer` memory-metrics API
+
+Both approaches print the memory values to the benchmark report.
+ - For Nvidia, the result column will be `nvidia_mem_reserved`
+ - For Torch/HF, the result columns will be `peak_torch_mem_alloc_in_bytes` and `torch_mem_alloc_in_bytes`
+
+### Nvidia `nvidia-smi`
+`nvidia-smi` is a command line utility (CLI) based on the NVIDIA Management Library (NVML). A separate process is used to start, log, and finally terminate the CLI for every experiment.
+
+The keyword `memory.used` is passed to the `--query-gpu` argument to log the memory usage at a fixed interval. The full list of keywords that can be logged is available via `nvidia-smi --help-query-gpu`.
+
+Since it runs in a separate process, it is less likely to affect the training. However, it is a coarser approach than HF, as NVML's definition of used memory is the sum of reserved and allocated device memory. Refer to their [documentation](https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html#structnvmlMemory__t:~:text=Sum%20of%20Reserved%20and%20Allocated%20device%20memory%20(in%20bytes).%20Note%20that%20the%20driver/GPU%20always%20sets%20aside%20a%20small%20amount%20of%20memory%20for%20bookkeeping).
+
+After every experiment,
+ - the logged values are calibrated to remove any pre-existing foreign memory usage,
+ - the peak value for each GPU device is taken,
+ - the values are finally averaged across all devices.
+
+### Torch/HuggingFace `HFTrainer`
+HFTrainer can log memory through the `skip_memory_metrics=False` training argument. Its [documentation](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments.skip_memory_metrics) mentions that setting this argument to `False` will affect training speed. In our tests so far (below), we do not see a significant difference in throughput (tokens/sec) when using this argument. A minimal usage sketch is shown after the notes below.
+
+The HFTrainer API is more granular than `nvidia-smi`, as it uses `torch.cuda` to pinpoint memory usage inside the trainer:
+ - It reports the allocated memory by calling `torch.cuda.memory_allocated()` and `torch.cuda.max_memory_allocated()` inside its probes
+ - It has memory logging probes at different stages of the Trainer - `init`, `train`, `evaluate`, `predict`
+
+##### NOTE:
+- When in distributed mode, the Trainer will only log the rank 0 memory.
+- For stability purposes, it only tracks the outer level of the train, evaluate and predict methods, i.e. if evaluate is called during train, there won't be a nested invocation of the memory probe.
+- Any GPU memory incurred outside of the defined Trainer stages won't be tracked.
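As a minimal sketch of how these metrics can be obtained directly from the Trainer (the `model` and `train_dataset` names are assumed to be defined elsewhere and are not part of the benchmark scripts):

```python
# Minimal sketch: enabling the HFTrainer memory probes via skip_memory_metrics=False.
# `model` and `train_dataset` are assumed to be defined elsewhere.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    skip_memory_metrics=False,  # enable the init/train/evaluate/predict memory probes
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
result = trainer.train()

# The returned metrics include the GPU memory keys discussed above, e.g.
# `before_init_mem_gpu`, `init_mem_gpu_alloc_delta`, `train_mem_gpu_peaked_delta`.
print({k: v for k, v in result.metrics.items() if "mem_gpu" in k})
```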
+
+### Additional Details
+
+#### Calculating Memory from HFTrainer Output Metrics
+
+This is an example of the memory values that HFTrainer will produce in the outputs of `train()`:
+```
+output_metrics = {
+    'train_runtime': 191.2491,
+    'train_samples_per_second': 0.209,
+    'train_steps_per_second': 0.052,
+    'train_tokens_per_second': 428.342,
+    'train_loss': 1.0627506256103516,
+    'init_mem_cpu_alloc_delta': 4096,
+    'init_mem_gpu_alloc_delta': 0,
+    'init_mem_cpu_peaked_delta': 0,
+    'init_mem_gpu_peaked_delta': 0,
+    'train_mem_cpu_alloc_delta': 839086080,
+    'train_mem_gpu_alloc_delta': -17491768832,
+    'train_mem_cpu_peaked_delta': 0,
+    'train_mem_gpu_peaked_delta': 26747825664,
+    'before_init_mem_cpu': 5513297920,
+    'before_init_mem_gpu': 36141687296,
+    'epoch': 0.01
+}
+```
+
+We refer to the keys of the memory metrics in this order:
+ - `before_init_mem_X` as stage0
+ - `init_mem_X` as stage1
+ - `train_mem_X` as stage2
+ - ...
+
+We currently compute the memory values in the report by taking the largest of sums. For example:
+
+For the allocated memory value:
+```
+max([
+  stage0_mem + stage1_allocated_delta,
+  stage0_mem + stage1_allocated_delta + stage2_allocated_delta,
+  ...
+])
+```
+
+For the peak memory value:
+```
+max([
+  stage0_mem + stage1_allocated_delta + stage1_peaked_delta,
+  stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta,
+  ...
+])
+```
+
+Notice that we do not include `stage0_mem` alone when computing the max value. This is to avoid misleading comparisons between GPTQ-LoRA and the other variants. GPTQ-LoRA + FSDP currently does not support low-memory mode, as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18). The `stage0_mem` value of GPTQ-LoRA + FSDP will reflect a larger-than-expected value, as the model is loaded fully before the trainer is initialized and only sharded internally afterwards in `trainer.prepare`. This could make comparisons misleading when other variants are loaded in low-memory mode and show a smaller `stage0_mem` than GPTQ-LoRA + FSDP. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back in the max computation. A small worked example of this computation is shown at the end of this section.
+
+We compare the memory values reported by Nvidia-SMI and Torch in this PR - [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14).
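As a quick, purely illustrative check, the formulas above can be evaluated with the example `output_metrics` shown earlier. Note that `extract_gpu_memory_metrics` in `benchmark.py` (below) currently includes only the `train` stage in the max, so its allocated value for this example would be the last sum rather than the overall max.

```python
# Illustrative only: plug the example metrics above into the max-of-sums formulas.
stage0_mem   = 36141687296   # before_init_mem_gpu
stage1_alloc = 0             # init_mem_gpu_alloc_delta
stage1_peak  = 0             # init_mem_gpu_peaked_delta
stage2_alloc = -17491768832  # train_mem_gpu_alloc_delta
stage2_peak  = 26747825664   # train_mem_gpu_peaked_delta

torch_mem_alloc_in_bytes = max(
    stage0_mem + stage1_alloc,
    stage0_mem + stage1_alloc + stage2_alloc,
)
peak_torch_mem_alloc_in_bytes = max(
    stage0_mem + stage1_alloc + stage1_peak,
    stage0_mem + stage1_alloc + stage2_alloc + stage2_peak,
)
print(torch_mem_alloc_in_bytes, peak_torch_mem_alloc_in_bytes)
# 36141687296 45397744128
```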
+ + diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 9d05b540..afbf61cf 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -13,6 +13,7 @@ from transformers import AutoConfig, HfArgumentParser, TrainingArguments import datasets import pandas as pd +import torch import yaml """ @@ -72,6 +73,68 @@ "torch.distributed.elastic.multiprocessing.errors.ChildFailedError" ] +FILE_MEM = "gpu_memory_logs.csv" +GPU_LOG_USED_MEM_COLUMN_NAME = "memory.used [MiB]" +GPU_LOG_METRIC_SUFFIX = " MiB" +GPU_TABLE = "timestamp,name,index,memory.used" +RESULT_FIELD_RESERVED_GPU_MEM = "nvidia_mem_reserved" +RESULT_FIELD_DEVICE_NAME = "gpu_device_name" + +HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT = "before_init_mem_gpu" +HF_TRAINER_LOG_GPU_STAGE_INIT = "init_mem_gpu" +HF_TRAINER_LOG_GPU_STAGE_TRAIN = "train_mem_gpu" +KEYWORD_PEAKED_DELTA = "peaked_delta" +KEYWORD_ALLOC_DELTA = "alloc_delta" +HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics" +RESULT_FIELD_ALLOCATED_GPU_MEM = "torch_mem_alloc_in_bytes" +RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "peak_torch_mem_alloc_in_bytes" + + +def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: + """ + This function computes the gpu summary metrics from the output metrics of Trainer + when `skip_memory_metrics` is set to `False` in transformers.TrainingArguments + + This function is called only when `--skip_memory_metrics` exist in the experiment arg + and is set to False. The memory key values are expected to be inside output_metrics. If + output_metrics is empty, return peak=0 and usage=0 + + Returns + - gpu_peak value in Bytes + - gpu_usage value in Bytes + """ + # Assumes train stage is always called + # this is a tuple of stage names, and a bool to say if it should be included in the summarized number + # we exclude the model loading stages for now, due to + # https://github.com/foundation-model-stack/fms-acceleration/issues/18 + # we will renable the loading stages later on once this issue is addressed + if len(output_metrics.keys()) < 1: + return 0, 0 + + trainer_stage_order = [ + (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False), + (HF_TRAINER_LOG_GPU_STAGE_INIT, False), + (HF_TRAINER_LOG_GPU_STAGE_TRAIN, True), + ] + alloc_running_sum = 0 + list_of_alloc_running_sums = [] + list_of_peak_running_sums = [] + for STAGE_NAME, include in trainer_stage_order: + delta_key = f"{STAGE_NAME}_{KEYWORD_ALLOC_DELTA}" + alloc_running_sum += ( + output_metrics[delta_key] + if delta_key in output_metrics + else output_metrics[STAGE_NAME] + ) + peak_delta = output_metrics.get(f"{STAGE_NAME}_{KEYWORD_PEAKED_DELTA}", 0) + if include: + list_of_alloc_running_sums.append(alloc_running_sum) + list_of_peak_running_sums.append(alloc_running_sum + peak_delta) + + max_alloc_running_sum = max(list_of_alloc_running_sums) + max_peak_running_sum = max(list_of_peak_running_sums) + return max_peak_running_sum, max_alloc_running_sum + def get_hf_arguments_with_no_value(dataclass_types): """this function will return a map (str, bool) of true/false arguments. 
@@ -292,8 +355,15 @@ def __init__( self.stderr_filename = os.path.join(self.save_dir, FILE_STDERR) self.command_filename = os.path.join(self.save_dir, FILE_SHELL_COMMAND) self.results_filename = os.path.join(self.save_dir, FILE_RESULTS) + self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM) - def run(self, run_cmd: str, environment_variables: Dict = None): + def run( + self, + run_cmd: str, + environment_variables: Dict = None, + log_nvidia_smi: bool = False, + memory_log_interval_secs: int = 1, + ): # form the command line commands = [] @@ -308,6 +378,39 @@ def run(self, run_cmd: str, environment_variables: Dict = None): self.environment = environment_variables self.experiment_args_str = commands os.makedirs(self.save_dir, exist_ok=True) + + if log_nvidia_smi: + """ + Opens a parallel process to log the device memory of the main experiment process. + - Logs memory at intervals to a csv file in `self.save_dir` + - Terminates at the end of experiment + - GPU log is read and aggregated when the experiment ends & results are saved in Experiment.write_result, + + NOTE: This feature assumes the following + 1. Experiment is the only process on the gpu devices - + there are no other processes running on the device in parallel. + + Can log more info from nvidia-smi by expanding GPU_Table argument + e.g. "timestamp,name,index,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used" + Use `nvidia-smi --help-query-gpu` for more reference + """ + nvidia_logging_cmd = [ + "nvidia-smi", + "--query-gpu", + GPU_TABLE, + "--format", + "csv", + "--id", + str(environment_variables["CUDA_VISIBLE_DEVICES"]), + "--loop", + str(memory_log_interval_secs), + ] + memory_process = subprocess.Popen( + nvidia_logging_cmd, + stdout=open(self.gpu_log_filename, "w"), + text=True, + ) + subprocess.run( self.shell_command, capture_output=False, @@ -317,6 +420,9 @@ def run(self, run_cmd: str, environment_variables: Dict = None): env={**os.environ.copy(), **environment_variables}, ) + if log_nvidia_smi: + memory_process.terminate() + def get_experiment_final_metrics( self, final_metrics_keys: List[str] = ["train_loss", "train_runtime"] ): @@ -374,6 +480,38 @@ def maybe_get_experiment_error_traceback(self): return None if len(results) == 0 else results + def get_peak_mem_usage_by_device_id(self): + """ + This function retrieves the raw measurements of reserved GPU memory per device across the experiment - + computing the peak value for each gpu and then performing a simple calibration (subtracts peak values by the first reading). + Returns: + - pd.Series of peak memory usage per device id + - the device name as string - e.g. 
"NVIDIA A100-SXM4-80GB" + + Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series: + + - pd.Series + index + 0 52729.0 + 1 52783.0 + Name: memory.used [MiB], dtype: float64 + """ + + # group the gpu readings into device ids + gpu_logs = pd.read_csv(self.gpu_log_filename, skipinitialspace=True) + # assume that all the devices have the same device name + device_name = gpu_logs.name.iloc[-1] + # extract and convert the gpu memory usage as float values + gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[ + GPU_LOG_USED_MEM_COLUMN_NAME + ].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, ""))) + mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME] + # Calibrate values by subtracting out the initial values of the GPU readings + # to ensure no existing memory is counted in addition with the experiment + initial_values = mem_usage_by_device_id.first() + peak_values = mem_usage_by_device_id.max() + return peak_values.sub(initial_values), device_name + def write_result(self): "Function to write a json result file" @@ -381,6 +519,30 @@ def write_result(self): save_result = ConfigUtils.convert_args_to_dict(self.experiment_args_str) save_result["num_gpus"] = self.num_gpus + # if a gpu log file exist, process the raw nvidia logs and write to result + if os.path.isfile(self.gpu_log_filename): + # Add GPU info and measurements into the result saving + peak_mem_usage_by_device_id, device_name = ( + self.get_peak_mem_usage_by_device_id() + ) + save_result[RESULT_FIELD_DEVICE_NAME] = device_name + # Memory usage is averaged across all devices in the final result + save_result[RESULT_FIELD_RESERVED_GPU_MEM] = ( + peak_mem_usage_by_device_id.mean() + ) + + # process gpu mem from output metrics and write to result + # check if HF_ARG_SKIP_MEMORY_METRIC is set to False in experiment arg + # this arg is specified explicitly inside `def generate_list_of_experiments`` + argument_idx = self.experiment_arg.index(HF_ARG_SKIP_MEMORY_METRIC) + write_memory_metric = not self.experiment_arg[argument_idx + 1] + if write_memory_metric: + peak_gpu_mem, gpu_allocated_mem = extract_gpu_memory_metrics( + self.get_experiment_final_metrics() + ) + save_result[RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM] = peak_gpu_mem + save_result[RESULT_FIELD_ALLOCATED_GPU_MEM] = gpu_allocated_mem + # if there is an error we save the error message else we save the final result maybe_error_messages = self.maybe_get_experiment_error_traceback() if maybe_error_messages is None: @@ -493,6 +655,7 @@ def generate_list_of_experiments( output_dir: str = "results", hf_products_dir: str = "hf", dry_run: bool = False, + log_memory_in_trainer: bool = False, ) -> List[Experiment]: """Construct list of experiments to be run. Takes in default_config and any matrices in scenario and experiment_config @@ -504,6 +667,8 @@ def generate_list_of_experiments( expr_arg_w_outputdir = exp_arg + [ "--output_dir", os.path.join(experiment_output_dir, hf_products_dir), + HF_ARG_SKIP_MEMORY_METRIC, + not log_memory_in_trainer, ] expr_cls = Experiment if not dry_run else DryRunExperiment _expr = expr_cls( @@ -578,6 +743,19 @@ def compress(df): def main(args): + # Gathers available gpu device ids that will be used for benchmarking. + # If "CUDA_VISIBLE_DEVICES" is specified, it will return the specified device ids + # if no gpu ids are specified, it will default to the enumeration of available ids + assert torch.cuda.device_count() > 0, "No device detected for memory logging!" 
+ available_gpus_indices = os.environ.get("CUDA_VISIBLE_DEVICES") + if available_gpus_indices: + available_gpus_indices = available_gpus_indices.split(",") + else: + available_gpus_indices = [str(i) for i in range(torch.cuda.device_count())] + + if args.dry_run and args.log_nvidia_smi: + args.log_nvidia_smi = False + # 1. Prepares a standard BenchmarkDataset # TODO: consider caching the json file if not args.no_data_processing: @@ -600,6 +778,7 @@ def main(args): experiment_args, output_dir=args.results_output_path, dry_run=args.dry_run, + log_memory_in_trainer=args.log_memory_hf, ) ): if experiment.num_gpus > 1: @@ -611,10 +790,20 @@ def main(args): else: prefix = COMMAND_PYTHON - device_ids = ",".join([str(i) for i in range(experiment.num_gpus)]) + assert experiment.num_gpus <= len( + available_gpus_indices + ), "Experiment requires more gpus than is available on the platform." + """ + Experiment will take only the ids from the available gpu indices, + this ensures that whatever GPUs are exposed to benchmark.py are the only + devices that each experiment can have access to. + """ + device_ids = ",".join(available_gpus_indices[: experiment.num_gpus]) + experiment.run( f"{prefix} {FMS_TRAINER}", environment_variables={"CUDA_VISIBLE_DEVICES": device_ids}, + log_nvidia_smi=args.log_nvidia_smi, ) # write results and store pointers to files @@ -746,5 +935,17 @@ def main(args): help="ensures 'model_name_or_paths 'specified in scenarios.yaml work. " "Useful to check model paths specified correctly before lengthly benchmark runs.", ) + parser.add_argument( + "--log_nvidia_smi", + action="store_true", + help="Use `nvidia-smi` API to log reserved memory of benchmarks", + ) + + parser.add_argument( + "--log_memory_hf", + action="store_true", + help="Uses memory logging from HF Trainer Arguments API to log gpu memory, for distributed runs only rank 0 is measured", + ) + args = parser.parse_args() main(args) diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py index 30b54e63..b590f26c 100644 --- a/scripts/benchmarks/display_bench_results.py +++ b/scripts/benchmarks/display_bench_results.py @@ -3,13 +3,17 @@ # First Party # import this because of alot of internal contants -from scripts.benchmarks.benchmark import gather_report +from scripts.benchmarks.benchmark import gather_report, DIR_SAMP_CONFIGS +from typing import List - -def main(*directories: str, output_filename: str = "results.csv"): +def main(*directories: str, output_filename: str = "results.csv", remove_columns: List[str] = None): "gather outputs from a list of directories and output to a csv" df, constant = gather_report(*directories, raw=False) + # filter result columns to keep by the inverse of remove_columns + if remove_columns: + df = df[df.columns[~df.columns.isin(remove_columns)]] + errors = [] try: # remove error messages if any @@ -19,9 +23,7 @@ def main(*directories: str, output_filename: str = "results.csv"): except: pass df = df.reset_index().drop("output_dir", axis=1) - df.reindex(sorted(df.columns), axis=1).to_csv( - output_filename, index=False - ) + df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False) print("***************** Report Created ******************") print(f"Total lines: '{len(df)}'") print(f"Number columns included: '{len(df.columns)}'") @@ -46,5 +48,11 @@ def main(*directories: str, output_filename: str = "results.csv"): default="results.csv", help="name of final csv report file.", ) + parser.add_argument( + "--remove_columns", + nargs="*", 
+ help="list of columns to ignore from results.csv", + ) + args = parser.parse_args() - main(args.bench_outputs, output_filename=args.result_file) + main(args.bench_outputs, output_filename=args.result_file, remove_columns=args.remove_columns) diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv index 93dd0a28..4434d864 100644 --- a/scripts/benchmarks/refs/a100_80gb.csv +++ b/scripts/benchmarks/refs/a100_80gb.csv @@ -1,49 +1,61 @@ -acceleration_framework_config_file,epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -,0.15,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,4,,,0.8943243026733398,561.4936,0.712,0.178,2917.932 -,0.15,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,2,,,0.8696886157989502,306.2728,1.306,0.327,2674.74 -,0.29,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,8,,,1.0190681648254394,1094.7748,0.731,0.091,2993.127 -,0.29,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,4,,,0.8909366416931153,572.0158,1.399,0.175,2864.256 -,,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,4,,,,,,, -,,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,2,,,,,,, -,,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,8,,,,,,, -,,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,4,,,,,,, -,,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,4,,,,,,, -,,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,2,,,,,,, -,,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,8,,,,,,, -,,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,4,,,,,,, -,0.15,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.8808393669128418,458.0185,0.873,0.218,3577.148 -,0.15,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.8548675441741943,259.6061,1.541,0.385,3155.55 -,0.29,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.007005090713501,915.9053,0.873,0.109,3577.662 -,0.29,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.8773036098480225,480.6995,1.664,0.208,3408.367 -,,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,, -,0.15,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.862400369644165,535.3534,0.747,0.187,1530.204 -,,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,, -,0.29,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.8798200416564942,924.5333,0.865,0.108,1772.137 -,,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,, -,,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,, -,,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,, -,,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,, -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.15,True,accelerated-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.8661054801940918,481.8265,0.83,0.208,3400.394 
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.15,True,accelerated-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.8560933685302734,271.0715,1.476,0.369,3022.081 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.29,True,accelerated-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.8718929100036621,951.8817,0.84,0.105,3442.445 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.29,True,accelerated-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.8511034965515136,498.9262,1.603,0.2,3283.852 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.15,True,accelerated-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.8973640727996827,908.6145,0.44,0.11,1803.185 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.15,True,accelerated-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.8554682540893555,548.0391,0.73,0.182,1494.784 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.29,True,accelerated-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.8935444927215577,1714.3117,0.467,0.058,1911.438 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.29,True,accelerated-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.8596937179565429,954.0851,0.838,0.105,1717.247 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.14,True,accelerated-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.000812177658081,3696.2907,0.108,0.027,443.255 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.14,True,accelerated-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9307080173492431,1960.7862,0.204,0.051,417.792 -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,True,accelerated-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,, -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.28,True,accelerated-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9387501430511475,3809.1796,0.21,0.026,430.119 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.15,True,accelerated-peft-autogptq,36,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9700051403045654,478.8299,0.835,0.209,3421.675 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.15,True,accelerated-peft-autogptq,37,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9695001697540283,270.0251,1.481,0.37,3033.792 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.29,True,accelerated-peft-autogptq,38,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9514076042175293,946.5715,0.845,0.106,3461.756 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.29,True,accelerated-peft-autogptq,39,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj 
o_proj,0.9824443531036376,496.6611,1.611,0.201,3298.829 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.15,True,accelerated-peft-autogptq,40,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9041421699523926,872.5836,0.458,0.115,1877.643 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.15,True,accelerated-peft-autogptq,41,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9010070323944092,499.3435,0.801,0.2,1640.554 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.29,True,accelerated-peft-autogptq,42,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9001609039306641,1666.1579,0.48,0.06,1966.68 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.29,True,accelerated-peft-autogptq,43,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.8965495491027832,897.4939,0.891,0.111,1825.528 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.14,True,accelerated-peft-autogptq,44,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9533391189575195,3621.8261,0.11,0.028,452.368 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.14,True,accelerated-peft-autogptq,45,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9467405033111572,1886.6815,0.212,0.053,434.202 -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,True,accelerated-peft-autogptq,46,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,, -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,True,accelerated-peft-autogptq,47,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,, +epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,nvidia_mem_reserved,peak_torch_mem_alloc_in_bytes,peft_method,per_device_train_batch_size,r,target_modules,torch_mem_alloc_in_bytes,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +0.04,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,77705.0,72971724288.0,,4,,,44004763136.0,0.9278398831685384,177.1092,0.678,0.169,2775.237 +0.04,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,44706.0,36762859520.0,,2,,,29521119232.0,0.8970902442932129,91.086,1.317,0.329,2698.11 +0.09,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,74383.0,72972117504.0,,8,,,44005156352.0,0.9879656155904134,322.458,0.744,0.093,3048.583 +0.09,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,53907.0,36763056128.0,,4,,,29521315840.0,0.9259945551554362,167.7727,1.431,0.179,2929.678 +,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,4,,,,,,,, +,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79353.0,,,2,,,,,,,, +,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,8,,,,,,,, +,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79827.0,,,4,,,,,,,, +,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,4,,,,,,,, +,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,80830.0,,,2,,,,,,,, +,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,8,,,,,,,, +,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,80834.5,,,4,,,,,,,, +0.04,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,29731.0,26108963328.0,lora,4,16,q_proj k_proj 
v_proj o_proj,15119590912.0,0.9096682230631511,136.624,0.878,0.22,3597.611 +0.04,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,18697.0,15123161088.0,lora,2,16,q_proj k_proj v_proj o_proj,7850391552.0,0.8918854713439941,82.0311,1.463,0.366,2995.936 +0.09,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,43195.0,37098695168.0,lora,8,16,q_proj k_proj v_proj o_proj,15119984128.0,0.962119706471761,270.6301,0.887,0.111,3632.412 +0.09,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,26235.0,21433753600.0,lora,4,16,q_proj k_proj v_proj o_proj,7850588160.0,0.9218235015869141,143.8184,1.669,0.209,3417.643 +,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +0.04,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,62617.0,57540387840.0,lora,2,16,q_proj k_proj v_proj o_proj,47311452160.0,0.9361546834309896,179.3128,0.669,0.167,1370.566 +,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, +0.09,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,69848.0,64347637760.0,lora,4,16,q_proj k_proj v_proj o_proj,47311648768.0,0.9383139928181966,280.8919,0.854,0.107,1749.855 +,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80894.0,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,, +,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, +,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80979.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +0.04,True,baseline-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,27023.0,22825932800.0,lora,4,16,q_proj k_proj v_proj o_proj,5368221184.0,0.9589527130126954,178.8061,0.671,0.168,2748.9 +0.04,True,baseline-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13530.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9154380798339844,87.3652,1.374,0.343,2813.02 +0.09,True,baseline-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,47145.0,40278956032.0,lora,8,16,q_proj k_proj v_proj o_proj,5368614400.0,0.9702634493509928,341.2286,0.703,0.088,2880.884 +0.09,True,baseline-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21502.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.914565912882487,149.9341,1.601,0.2,3278.241 +0.04,True,baseline-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,48313.0,46419968512.0,lora,4,16,q_proj k_proj v_proj o_proj,25726225920.0,0.9744932492574055,351.8623,0.341,0.085,1396.91 +0.04,True,baseline-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25549.0,21922782720.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9303209940592448,171.4299,0.7,0.175,1433.589 +0.09,True,baseline-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,69931.0,67089150464.0,lora,8,16,q_proj k_proj v_proj o_proj,25726619136.0,0.9745417594909668,629.837,0.381,0.048,1560.785 +0.09,True,baseline-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29384115200.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9310146331787109,300.5119,0.799,0.1,1635.609 +,True,baseline-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80893.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +0.04,True,baseline-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52634.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.0399916648864747,584.3145,0.205,0.051,420.595 
+,True,baseline-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,79557.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, +,True,baseline-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80749.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +0.04,True,accelerated-peft-bnb,36,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,19931.0,15860019712.0,lora,4,16,q_proj k_proj v_proj o_proj,4843384320.0,0.9652111371358235,143.3569,0.837,0.209,3428.645 +0.04,True,accelerated-peft-bnb,37,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13497.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9277165730794271,86.4307,1.388,0.347,2843.435 +0.09,True,accelerated-peft-bnb,38,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,34355.0,26849751552.0,lora,8,16,q_proj k_proj v_proj o_proj,4843777536.0,0.9493892669677735,279.7156,0.858,0.107,3514.427 +0.09,True,accelerated-peft-bnb,39,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21479.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.9110882759094239,149.3914,1.607,0.201,3290.15 +0.04,True,accelerated-peft-bnb,40,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,38405.0,36218024448.0,lora,4,16,q_proj k_proj v_proj o_proj,25201389056.0,0.9741149584452311,278.5888,0.431,0.108,1764.32 +0.04,True,accelerated-peft-bnb,41,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25592.0,21906697728.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9300654411315918,172.7359,0.695,0.174,1422.75 +0.09,True,accelerated-peft-bnb,42,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,50875.0,47207756288.0,lora,8,16,q_proj k_proj v_proj o_proj,25201782272.0,0.9748441060384114,512.2298,0.469,0.059,1919.139 +0.09,True,accelerated-peft-bnb,43,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29369087488.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9301350593566895,287.6381,0.834,0.104,1708.814 +0.04,True,accelerated-peft-bnb,44,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,72829.0,68159977472.0,lora,4,16,q_proj k_proj v_proj o_proj,37346815488.0,1.118430455525716,1075.2044,0.112,0.028,457.141 +0.04,True,accelerated-peft-bnb,45,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52632.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.040946865081787,586.651,0.205,0.051,418.92 +,True,accelerated-peft-bnb,46,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80405.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, +,True,accelerated-peft-bnb,47,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80954.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, +0.04,True,accelerated-peft-autogptq,48,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,20453.0,15890329088.0,lora,4,16,q_proj k_proj v_proj o_proj,4873693696.0,1.3805528958638509,151.0359,0.795,0.199,3254.326 +0.04,True,accelerated-peft-autogptq,49,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,17198.0,9952175616.0,lora,2,16,q_proj k_proj v_proj o_proj,3005709312.0,1.1706618309020995,87.4109,1.373,0.343,2811.548 +0.09,True,accelerated-peft-autogptq,50,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,34247.0,26880060928.0,lora,8,16,q_proj k_proj v_proj o_proj,4874086912.0,1.2741642634073893,282.6391,0.849,0.106,3478.076 +0.09,True,accelerated-peft-autogptq,51,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,24783.0,16262768128.0,lora,4,16,q_proj k_proj v_proj o_proj,3005905920.0,1.043952751159668,152.5473,1.573,0.197,3222.083 +0.04,True,accelerated-peft-autogptq,52,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,37461.0,35528093184.0,lora,4,16,q_proj k_proj v_proj o_proj,24511457792.0,0.9936613400777181,263.6066,0.455,0.114,1864.597 
+0.04,True,accelerated-peft-autogptq,53,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,46641.0,25708175360.0,lora,2,16,q_proj k_proj v_proj o_proj,12788874240.0,0.9420519828796386,167.065,0.718,0.18,1471.045 +0.09,True,accelerated-peft-autogptq,54,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,49925.0,46517825024.0,lora,8,16,q_proj k_proj v_proj o_proj,24511851008.0,0.9855653127034505,498.9022,0.481,0.06,1970.406 +0.09,True,accelerated-peft-autogptq,55,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,52358.0,27739090432.0,lora,4,16,q_proj k_proj v_proj o_proj,12789070848.0,0.9389812151590983,281.8034,0.852,0.106,1744.195 +0.04,True,accelerated-peft-autogptq,56,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,71565.0,65895347200.0,lora,4,16,q_proj k_proj v_proj o_proj,36290144768.0,1.0755928039550782,1060.8387,0.113,0.028,463.331 +0.04,True,accelerated-peft-autogptq,57,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80387.0,45397678592.0,lora,2,16,q_proj k_proj v_proj o_proj,18649885696.0,1.0256956418355305,576.0422,0.208,0.052,426.635 +,True,accelerated-peft-autogptq,58,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,80293.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, +0.08,True,accelerated-peft-autogptq,59,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80363.0,70667573760.0,lora,4,16,q_proj k_proj v_proj o_proj,18650082304.0,1.0266701062520345,1089.3291,0.22,0.028,451.214 diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index 21b3f98a..248eacb2 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -81,4 +81,4 @@ scenarios: model_name_or_path: - 'TheBloke/Mistral-7B-v0.1-GPTQ' - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' - - TheBloke/Nous-Hermes-Llama2-70B-GPTQ \ No newline at end of file + - 'TheBloke/Llama-2-70B-GPTQ' \ No newline at end of file diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 758e2d9e..e08125b3 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -33,10 +33,12 @@ BENCH_RESULT_FILE=benchmarks.csv # freeze the pip requirements here PIP_REQUIREMENTS_FILE=requirements.txt +# ------------- DROP COLUMNS FRO RESULTS ----------------- # env inputs DRY_RUN=${DRY_RUN:-"false"} NO_DATA_PROCESSING=${NO_DATA_PROCESSING:-"false"} NO_OVERWRITE=${NO_OVERWRITE:-"false"} +MEMORY_LOGGING=${MEMORY_LOGGING:-"huggingface"} # inputs NUM_GPUS_MATRIX=${1-"1 2"} @@ -48,6 +50,7 @@ echo "NUM_GPUS_MATRIX: $NUM_GPUS_MATRIX" echo "RESULT_DIR: $RESULT_DIR" echo "SCENARIOS_CONFIG: $SCENARIOS_CONFIG" echo "SCENARIOS_FILTER: $SCENARIOS_FILTER" +echo "MEMORY_LOGGING: $MEMORY_LOGGING" if [ -n "$RESULT_DIR" ]; then echo "The results directory is not empty. " @@ -86,6 +89,14 @@ if [ "$NO_DATA_PROCESSING" = "true" ]; then EXTRA_ARGS="$EXTRA_ARGS --no_data_processing" fi +if [ "$MEMORY_LOGGING" = "huggingface" ]; then + EXTRA_ARGS="$EXTRA_ARGS --log_memory_hf" +elif [ "$MEMORY_LOGGING" = "nvidia" ]; then + EXTRA_ARGS="$EXTRA_ARGS --log_nvidia_smi" +elif [ "$MEMORY_LOGGING" = "all" ]; then + EXTRA_ARGS="$EXTRA_ARGS --log_nvidia_smi --log_memory_hf" +fi + # dump out the environment pip freeze > $PIP_REQUIREMENTS_FILE @@ -101,6 +112,20 @@ python $WORKING_DIR/benchmark.py \ # produce the final CSV for checkin # need to set PYTHONPATH because there is an import inside # this will write to the BENCH_RESULT_FILE +# Remove the columns with values already represented by other metrics in the summary report PYTHONPATH=. 
\ python $WORKING_DIR/display_bench_results.py benchmark_outputs \ - --result_file $BENCH_RESULT_FILE + --result_file $BENCH_RESULT_FILE \ + --remove_columns \ + 'before_init_mem_cpu' \ + 'before_init_mem_gpu' \ + 'init_mem_cpu_alloc_delta' \ + 'init_mem_cpu_peaked_delta' \ + 'init_mem_gpu_alloc_delta' \ + 'init_mem_gpu_peaked_delta' \ + 'train_mem_cpu_alloc_delta' \ + 'train_mem_cpu_peaked_delta' \ + 'train_mem_gpu_alloc_delta' \ + 'train_mem_gpu_peaked_delta' \ + 'acceleration_framework_config_file' +
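For reference, a hypothetical invocation showing how the memory logging mode is selected via the `MEMORY_LOGGING` environment variable added above (assuming the default positional arguments defined in `run_benchmarks.sh`):

```sh
# log with both nvidia-smi and the HFTrainer memory probes
MEMORY_LOGGING=all bash scripts/run_benchmarks.sh

# log only nvidia-smi reserved memory (default is "huggingface")
MEMORY_LOGGING=nvidia bash scripts/run_benchmarks.sh
```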