[Benchmark] Memory benchmark utils (#4198)
* improve memory benchmarking
* correct typo
* fix current memory
* check torch memory allocated
* better pytorch function
* add total cached gpu memory
* add total gpu required
* improve torch gpu usage
* update memory usage
* finalize memory tracing
* save intermediate benchmark class
* fix conflict
* improve benchmark
* improve benchmark
* finalize
* make style
* improve benchmarking
* correct typo
* make train function more flexible
* fix csv save
* better repr of bytes
* better print
* fix __repr__ bug
* finish plot script
* rename plot file
* delete csv and small improvements
* fix in plot
* fix in plot
* correct usage of timeit
* remove redundant line
* remove redundant line
* fix bug
* add hf parser tests
* add versioning and platform info
* make style
* add gpu information
* ensure backward compatibility
* finish adding all tests
* Update src/transformers/benchmark/benchmark_args.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Update src/transformers/benchmark/benchmark_args_utils.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* delete csv files
* fix isort ordering
* add out of memory handling
* add better train memory handling

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
1 parent ec4cdfd, commit 96f57c9
Showing 14 changed files with 934 additions and 744 deletions.
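The commit log above refers repeatedly to tracking allocated and cached GPU memory through PyTorch ("check torch memory allocated", "add total cached gpu memory"). As a rough illustration of that measurement pattern, and not the code this commit adds, the sketch below uses only public torch.cuda APIs; the model call and the input shape are placeholder assumptions.

import torch


def measure_peak_gpu_memory(model, batch_size=8, seq_len=128, device="cuda"):
    # Reset the allocator's peak statistics so the measurement covers only this call.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    model = model.to(device)
    # Placeholder input: a batch of random token ids, as a transformers model would expect.
    input_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)

    with torch.no_grad():
        model(input_ids)

    # Peak memory allocated by tensors vs. total memory reserved (cached) by the CUDA allocator.
    peak_allocated_mb = torch.cuda.max_memory_allocated(device) / 1024 ** 2
    peak_reserved_mb = torch.cuda.max_memory_reserved(device) / 1024 ** 2
    return peak_allocated_mb, peak_reserved_mb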
@@ -0,0 +1,113 @@
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

import numpy as np

import matplotlib.pyplot as plt
from transformers import HfArgumentParser


@dataclass
class PlotArguments:
    """
    Arguments controlling which benchmark result csv file to plot and how to plot it.
    """

    csv_file: str = field(metadata={"help": "The csv file to plot."},)
    plot_along_batch: bool = field(
        default=False,
        metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."},
    )
    is_time: bool = field(
        default=False,
        metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
    )
    is_train: bool = field(
        default=False,
        metadata={
            "help": "Whether the csv file has training results or inference results. Defaults to inference results."
        },
    )
    figure_png_file: Optional[str] = field(
        default=None,
        metadata={"help": "Filename under which the plot will be saved. If unused, no plot is saved and the figure is shown instead."},
    )


class Plot:
    def __init__(self, args):
        self.args = args
        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))

        # Collect one result per (batch_size, sequence_length) pair for every model in the csv file.
        with open(self.args.csv_file, newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                model_name = row["model"]
                self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
                self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
                self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
                    "result"
                ]

    def plot(self):
        fig, ax = plt.subplots()
        title_str = "Time usage" if self.args.is_time else "Memory usage"
        title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"

        for model_name in self.result_dict.keys():
            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
            results = self.result_dict[model_name]["result"]

            (x_axis_array, inner_loop_array) = (
                (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
            )

            plt.xlim(min(x_axis_array), max(x_axis_array))

            # Time results are float seconds, memory results are integer MB.
            y_dtype = np.float32 if self.args.is_time else int

            for inner_loop_value in inner_loop_array:
                if self.args.plot_along_batch:
                    y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=y_dtype)
                else:
                    y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=y_dtype)

                # Both axes are log-scaled: sizes in powers of 2, results in powers of 10.
                ax.set_xscale("log", base=2)
                ax.set_yscale("log", base=10)

                (x_axis_label, inner_loop_label) = (
                    ("batch_size", "sequence_length in #tokens")
                    if self.args.plot_along_batch
                    else ("sequence_length in #tokens", "batch_size")
                )

                x_axis_array = np.asarray(x_axis_array, dtype=int)
                plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
                plt.plot(x_axis_array, y_axis_array, "--")

            title_str += f" {model_name} vs."

        # Drop the trailing " vs." after the last model name.
        title_str = title_str[:-4]
        y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"

        # Finalize and render the figure.
        plt.title(title_str)
        plt.xlabel(x_axis_label)
        plt.ylabel(y_axis_label)
        plt.legend()

        if self.args.figure_png_file is not None:
            plt.savefig(self.args.figure_png_file)
        else:
            plt.show()


def main():
    parser = HfArgumentParser(PlotArguments)
    plot_args = parser.parse_args_into_dataclasses()[0]
    plot = Plot(args=plot_args)
    plot.plot()


if __name__ == "__main__":
    main()
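For reference, the plotting utilities above can also be driven from Python instead of the command line. A minimal sketch, assuming a hypothetical results file plot.csv written by a benchmark run and an arbitrary output filename:

# Hypothetical programmatic use of the PlotArguments / Plot classes defined above;
# "plot.csv" and "memory_plot.png" are placeholder filenames.
plot_args = PlotArguments(
    csv_file="plot.csv",           # csv with benchmark results
    plot_along_batch=False,        # plot against sequence length
    is_time=False,                 # memory results rather than time results
    figure_png_file="memory_plot.png",
)
Plot(args=plot_args).plot()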
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """

from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments


def main():
    parser = HfArgumentParser(PyTorchBenchmarkArguments)
    benchmark_args = parser.parse_args_into_dataclasses()[0]
    benchmark = PyTorchBenchmark(args=benchmark_args)
    benchmark.run()


if __name__ == "__main__":
    main()
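The benchmark runner above can likewise be configured directly from Python rather than through command-line flags. A minimal sketch, assuming the PyTorchBenchmarkArguments fields introduced in this PR are named models, batch_sizes and sequence_lengths (the model identifier and sizes are arbitrary examples); the benchmark's csv output, mentioned in the commit log, is what the plotting script above consumes.

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Field names below are assumed from the benchmark argument dataclasses in this PR.
benchmark_args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"],
    batch_sizes=[8],
    sequence_lengths=[32, 128, 512],
)
benchmark = PyTorchBenchmark(args=benchmark_args)
results = benchmark.run()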