[Benchmark] Memory benchmark utils (#4198)
* improve memory benchmarking

* correct typo

* fix current memory

* check torch memory allocated

* better pytorch function

* add total cached gpu memory

* add total gpu required

* improve torch gpu usage

* update memory usage

* finalize memory tracing

* save intermediate benchmark class

* fix conflict

* improve benchmark

* improve benchmark

* finalize

* make style

* improve benchmarking

* correct typo

* make train function more flexible

* fix csv save

* better repr of bytes

* better print

* fix __repr__ bug

* finish plot script

* rename plot file

* delete csv and small improvements

* fix in plot

* fix in plot

* correct usage of timeit

* remove redundant line

* remove redundant line

* fix bug

* add hf parser tests

* add versioning and platform info

* make style

* add gpu information

* ensure backward compatibility

* finish adding all tests

* Update src/transformers/benchmark/benchmark_args.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/benchmark/benchmark_args_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* delete csv files

* fix isort ordering

* add out of memory handling

* add better train memory handling

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
patrickvonplaten and LysandreJik authored May 27, 2020
1 parent ec4cdfd commit 96f57c9
Showing 14 changed files with 934 additions and 744 deletions.
113 changes: 113 additions & 0 deletions examples/benchmarking/plot_csv_file.py
@@ -0,0 +1,113 @@
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

import numpy as np

import matplotlib.pyplot as plt
from transformers import HfArgumentParser


@dataclass
class PlotArguments:
    """
    Arguments controlling which csv file to plot and how to plot it.
    """

    csv_file: str = field(metadata={"help": "The csv file to plot."},)
    plot_along_batch: bool = field(
        default=False,
        metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."},
    )
    is_time: bool = field(
        default=False,
        metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
    )
    is_train: bool = field(
        default=False,
        metadata={
            "help": "Whether the csv file has training results or inference results. Defaults to inference results."
        },
    )
    figure_png_file: Optional[str] = field(
        default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
    )


class Plot:
    def __init__(self, args):
        self.args = args
        # Map each model name to its batch sizes, sequence lengths, and a
        # (batch_size, sequence_length) -> result lookup table.
        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))

        with open(self.args.csv_file, newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                model_name = row["model"]
                self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
                self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
                self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
                    "result"
                ]

    def plot(self):
        fig, ax = plt.subplots()
        title_str = "Time usage" if self.args.is_time else "Memory usage"
        title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"

        for model_name in self.result_dict.keys():
            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
            results = self.result_dict[model_name]["result"]

            (x_axis_array, inner_loop_array) = (
                (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
            )

            plt.xlim(min(x_axis_array), max(x_axis_array))

            for inner_loop_value in inner_loop_array:
                # Time results are floats (seconds), memory results are integers (MB).
                y_dtype = np.float32 if self.args.is_time else int
                if self.args.plot_along_batch:
                    y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=y_dtype)
                else:
                    y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=y_dtype)

                ax.set_xscale("log", basex=2)
                ax.set_yscale("log", basey=10)

                (x_axis_label, inner_loop_label) = (
                    ("batch_size", "sequence_length in #tokens")
                    if self.args.plot_along_batch
                    else ("sequence_length in #tokens", "batch_size")
                )

                x_axis_array = np.asarray(x_axis_array, int)
                plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
                plt.plot(x_axis_array, y_axis_array, "--")

            # Each model appends " <model> vs." to the title; the trailing " vs." is stripped below.
            title_str += f" {model_name} vs."

        title_str = title_str[:-4]
        y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"

        # plot
        plt.title(title_str)
        plt.xlabel(x_axis_label)
        plt.ylabel(y_axis_label)
        plt.legend()

        if self.args.figure_png_file is not None:
            plt.savefig(self.args.figure_png_file)
        else:
            plt.show()


def main():
    parser = HfArgumentParser(PlotArguments)
    plot_args = parser.parse_args_into_dataclasses()[0]
    plot = Plot(args=plot_args)
    plot.plot()


if __name__ == "__main__":
    main()
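
For reference, a minimal usage sketch of the plot script. The file name results.csv, its example row, and plot.png are illustrative and not part of the commit; the column names model, batch_size, sequence_length and result are the ones the Plot class reads above, and HfArgumentParser derives the equivalent CLI flags (--csv_file, --figure_png_file) from the PlotArguments field names:

# Hypothetical input file results.csv (header row plus one illustrative row):
#   model,batch_size,sequence_length,result
#   bert-base-uncased,8,128,1539
# Equivalent CLI call: python plot_csv_file.py --csv_file results.csv --figure_png_file plot.png
from plot_csv_file import Plot, PlotArguments

plot_args = PlotArguments(csv_file="results.csv", figure_png_file="plot.png")
Plot(args=plot_args).plot()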
29 changes: 29 additions & 0 deletions examples/benchmarking/run_benchmark.py
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training """

from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments


def main():
    parser = HfArgumentParser(PyTorchBenchmarkArguments)
    benchmark_args = parser.parse_args_into_dataclasses()[0]
    benchmark = PyTorchBenchmark(args=benchmark_args)
    benchmark.run()


if __name__ == "__main__":
    main()
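
The benchmark can also be driven programmatically rather than through the CLI. A minimal sketch, assuming PyTorchBenchmarkArguments exposes models, batch_sizes and sequence_lengths fields (the full argument set lives in src/transformers/benchmark/benchmark_args_utils.py, updated in this commit), with bert-base-uncased as an illustrative model identifier:

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Assumed argument names: models, batch_sizes, sequence_lengths.
benchmark_args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[32, 128, 512]
)
benchmark = PyTorchBenchmark(args=benchmark_args)
results = benchmark.run()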