Add benchmarks to CI (#479)
Summary:
## Types of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [X] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue
There's a task #368 for committing the benchmark code. In this change I add these benchmarks to the CI integration tests. To choose thresholds, I ran the benchmarks locally on all the layers (batch size: 16, num_runs: 100, num_repeats: 20, forward_only: False) and generated the following report:

| base_layer   | memory* control | memory* dp | memory* dp/control | memory* gsm | memory* gsm/control | runtime control | runtime dp | runtime dp/control | runtime gsm | runtime gsm/control |
|--------------|---------|--------|------------|--------|-------------|------------------------|----------------------|--------------------|------------------------|--------------------|
| conv         | 0.0     |        |            | 0.0    |             | 2.021756922606001      |                      |                    | 3.2889059911645036     | 1.6267563891534373 |
| embedding    | 0.0     |        |            | 0.0    |             | 0.002484286398502263   |                      |                    | 0.013664713416999803   | 5.5004581698946    |
| groupnorm    | 0.0     |        |            | 0.0    |             | 0.0001871487290072764  |                      |                    | 0.00043170701800136156 | 2.306759016165034  |
| gru          | 0.0     | 0.0    |            | 0.0    |             | 0.045029744959007065   | 0.057370035271503174 | 1.2740475284443677 | 0.2402042072270033     | 5.334345274344187  |
| instancenorm | 0.0     |        |            | 0.0    |             | 0.004493124293996517   |                      |                    | 0.006058429501005777   | 1.3483779002287433 |
| layernorm    | 0.0     |        |            | 0.0    |             | 0.00011227587499979562 |                      |                    | 0.0002241125804985131  | 1.9960884784814286 |
| linear       | 0.0     |        |            | 0.0    |             | 0.001010556231000001   |                      |                    | 0.003052972127999998   | 3.021080900148341  |
| lstm         | 0.0     | 0.0    |            | 0.0    |             | 0.052634652085002925   | 0.06508583683050075  | 1.2365586975931682 | 0.2982182763324963     | 5.665816425477371  |
| mha          | 0.0     | 0.0    |            | 0.0    |             | 0.018872260358001765   | 0.01870937360499738  | 0.9913689854890476 | 0.02688384014700477    | 1.424516175435558  |
| rnn          | 0.0     | 0.0    |            | 0.0    |             | 0.01576623683249454    | 0.02184348723049516  | 1.3854597937711604 | 0.10178373254250346    | 6.455803856296582  |

(*) This report wasn't generated on a machine with CUDA, so memory wasn't measured. I will update the memory numbers once the benchmarks run in CI on a GPU machine.

Using the report and Section 3 of the [paper](https://arxiv.org/pdf/2109.12298.pdf), I parameterised the runtime and memory thresholds for the different layers.
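
For reference, the CI gating step boils down to the check below (an expanded, readable sketch of the one-liner added to `.circleci/config.yml`; the `header`/`index_col` arguments assume the two header rows and five index columns that `generate_report.py` writes):

```python
import sys

import pandas as pd

# The report written by generate_report.py has two header rows
# (value kind and variant) and five index columns.
report = pd.read_csv(
    "./benchmarks/results/report.csv",
    header=[0, 1],
    index_col=[0, 1, 2, 3, 4],
).fillna(0)

threshold = 7.0  # runtime_ratio_threshold; set per layer group in CI

# Fail the job if any DP or GSM layer is slower than its vanilla
# counterpart by more than the allowed factor.
ok = (report[("runtime", "dp/control")] < threshold).all() and (
    report[("runtime", "gsm/control")] < threshold
).all()
sys.exit(0 if ok else 1)
```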

## How Has This Been Tested (if it applies)
I ran the jobs locally and generated the reports.

## Checklist

- [X] The documentation is up-to-date with the changes I made.
- [X] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [ ] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #479

Differential Revision: D38999201

Pulled By: moaradwan

fbshipit-source-id: 3d02931970e39ea331674c9f0676db9e22c5edaa
Attia Radwan authored and facebook-github-bot committed Aug 25, 2022
1 parent 79781b1 commit baea28a
Showing 5 changed files with 179 additions and 11 deletions.
49 changes: 49 additions & 0 deletions .circleci/config.yml
@@ -238,6 +238,35 @@ commands:
      - store_artifacts:
          path: runs/charlstm/test-reports

  benchmark_layers_integration_test:
    description: "Runs benchmark end to end"
    parameters:
      device:
        default: "cpu"
        type: string
      layers:
        default: "mha dpmha gsm_dpmha embedding gsm_embedding instancenorm gsm_instancenorm groupnorm gsm_groupnorm layernorm gsm_layernorm lstm dplstm gsm_dplstm rnn dprnn gsm_dprnn linear gsm_linear gru dpgru gsm_dpgru"
        type: string
      runtime_ratio_threshold:
        default: 7.0
        type: float
      memory_ratio_threshold:
        default: 2.0
        type: float
    steps:
      - run:
          name: benchmarks
          command: |
            mkdir -p benchmarks/results/raw
            echo "Using $(python -V) ($(which python))"
            echo "Using $(pip -V) ($(which pip))"
            python benchmarks/run_benchmarks.py --batch_size 16 --layers <<parameters.layers>> --config_file ./benchmarks/config.json --root ./benchmarks/results/raw/ --cont
            layers="<<parameters.layers>>"; layers_list=(${(@s: :)layers}); mkdir /tmp/report_layers; cp -v ${:-${^layers_list}*} /tmp/report_layers
            python benchmarks/generate_report.py --path-to-results /tmp/report_layers --save-path benchmarks/results/report.csv
            python -c "import pandas as pd; r = pd.read_csv('./benchmarks/results/report.csv', header=[0, 1], index_col=[0, 1, 2, 3, 4]).fillna(0); th="<<parameters.runtime_ratio_threshold>>"; exit(0) if (r.loc[:, ('runtime', 'dp/control')] < th).all() and (r.loc[:, ('runtime', 'gsm/control')] < th).all() else exit(1)"
          when: always
      - store_artifacts:
          path: benchmarks/results/
# -------------------------------------------------------------------------------------
# Jobs
# -------------------------------------------------------------------------------------
@@ -315,6 +344,26 @@ jobs:
device: "cuda"
- dcgan_integration_test:
device: "cuda"
- benchmark_layer_integration_test:
device: "cuda"
layers: groupnorm gsm_groupnorm gru dp_gru instancenorm gsm_instancenorm layernorm gsm_layernorm lstm dplstm mha dpmha gsm_dpmha rnn dprnn
runtime_ratio_threshold: 2.5
memory_ratio_threshold: 1.6
- benchmark_layer_integration_test:
device: "cuda"
layers: linear gsm_linear
runtime_ratio_threshold: 3.3
memory_ratio_threshold: 13
- benchmark_layer_integration_test:
device: "cuda"
layers: gru gsm_dpgru lstm gsm_dplstm rnn gsm_dprnn
runtime_ratio_threshold: 7
memory_ratio_threshold: 1.5
- benchmark_layer_integration_test:
device: "cuda"
layers: embedding gsm_embedding
runtime_ratio_threshold: 6
memory_ratio_threshold: 15

  unittest_multi_gpu:
    machine:
15 changes: 14 additions & 1 deletion benchmarks/README.md
@@ -33,7 +33,7 @@ Do this num_runs times:
    loss.backward()
Stop timer
Return elapsed time / num_repeats and memory statistics
```
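
As a minimal, self-contained sketch of that measurement (hypothetical layer and input sizes; the real loop lives in `benchmarks/benchmark_layer.py`):

```python
import torch
from torch.utils import benchmark

# Hypothetical layer and input; the real benchmark builds these per layer type.
layer = torch.nn.Linear(512, 512)
x = torch.randn(16, 512)

def benchmark_fun():
    loss = layer(x).sum()
    loss.backward()

# benchmark.Timer performs its own warmups.
timer = benchmark.Timer(
    stmt="benchmark_fun()",
    globals={"benchmark_fun": benchmark_fun},
    num_threads=1,
)
runtime = timer.timeit(20).mean  # average runtime over num_repeats=20 invocations
```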

@@ -107,6 +107,19 @@ optional arguments:
-v, --verbose
```

`generate_report.py` takes as input the path where `run_benchmarks.py` has written its results and generates a CSV report.
```
usage: generate_report.py [-h] [--path-to-results PATH_TO_RESULTS]
                          [--save-path SAVE_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --path-to-results PATH_TO_RESULTS
                        the path that `run_benchmarks.py` has saved results
                        to.
  --save-path SAVE_PATH
                        path to save the CSV output.
```
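
For example: `python benchmarks/generate_report.py --path-to-results ./results/raw --save-path ./results/report.csv` (the defaults above). The same report can also be produced programmatically; a minimal sketch, assuming it runs from the `benchmarks/` directory so that `utils` resolves:

```python
from utils import generate_report

# Summarise every result file that run_benchmarks.py wrote under
# ./results/raw into a single CSV report.
generate_report("./results/raw", "./results/report.csv")
```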
## Tests

```python -m pytest tests/```
15 changes: 6 additions & 9 deletions benchmarks/benchmark_layer.py
@@ -62,15 +62,12 @@ def run_layer_benchmark(
     )
 
     # benchmark.Timer performs its own warmups
-    try:
-        timer = benchmark.Timer(
-            stmt="benchmark_fun()",
-            globals={"benchmark_fun": benchmark_fun},
-            num_threads=1,
-        )
-        runtime = timer.timeit(num_repeats).mean
-    except RuntimeError:
-        runtime = float("nan")
+    timer = benchmark.Timer(
+        stmt="benchmark_fun()",
+        globals={"benchmark_fun": benchmark_fun},
+        num_threads=1,
+    )
+    runtime = timer.timeit(num_repeats).mean
 
     # get max memory allocated and reset memory statistics
     memory_stats["max_memory"] = reset_peak_memory_stats(device).prev_max_mem
36 changes: 36 additions & 0 deletions benchmarks/generate_report.py
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse


from utils import generate_report

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path-to-results",
        default="./results/raw",
        type=str,
        help="the path that `run_benchmarks.py` has saved results to.",
    )
    parser.add_argument(
        "--save-path",
        default="./results/report.csv",
        type=str,
        help="path to save the CSV output.",
    )
    args = parser.parse_args()

    generate_report(args.path_to_results, args.save_path)
75 changes: 74 additions & 1 deletion benchmarks/utils.py
@@ -13,12 +13,14 @@
# limitations under the License.

import pickle
import glob
from collections import namedtuple
from typing import Any, Dict, List, Optional

import torch
from layers import LayerType

import numpy as np
import pandas as pd

Memory = namedtuple("Memory", "prev_max_mem, cur_mem")

@@ -163,3 +165,74 @@ def save_results(
handle,
protocol=pickle.HIGHEST_PROTOCOL,
)


def generate_report(path_to_results: str, save_path: str) -> None:
    """Generate a report from the benchmark results.

    The output is a CSV file which contains the runtime and memory of each layer.
    If multiple variants of a layer were run (vanilla torch.nn, DP, or GSM),
    the performance of the DP and GSM variants is compared to torch.nn.

    Args:
        path_to_results: the path that `run_benchmarks.py` has saved results to.
        save_path: path to save the CSV output.
    """
    path_to_results = (
        path_to_results if path_to_results[-1] != "/" else path_to_results[:-1]
    )
    files = glob.glob(f"{path_to_results}/*")

    if len(files) == 0:
        raise Exception(f"There were no result files in the path {path_to_results}")

    raw_results = []
    for result_file in files:
        with open(result_file, "rb") as handle:
            raw_results.append(pickle.load(handle))

    results_dict = []
    for raw in raw_results:
        runtime = np.mean([i["runtime"] for i in raw["results"]])
        memory = np.mean([i["memory_stats"]["max_memory"] for i in raw["results"]])
        result = {
            "layer": raw["layer"],
            "batch_size": raw["batch_size"],
            "num_runs": raw["num_runs"],
            "num_repeats": raw["num_repeats"],
            "forward_only": raw["forward_only"],
            "runtime": runtime,
            "memory": memory,
        }
        results_dict.append(result)

    results = pd.DataFrame(results_dict)
    results["variant"] = "control"
    # use .loc to avoid pandas chained-assignment pitfalls
    results.loc[results["layer"].str.startswith("gsm"), "variant"] = "gsm"
    results.loc[results["layer"].str.startswith("dp"), "variant"] = "dp"
    results["base_layer"] = results["layer"].str.replace(
        "(gsm_)|(dp)", "", regex=True
    )

    pivot = results.pivot_table(
        index=["batch_size", "num_runs", "num_repeats", "forward_only", "base_layer"],
        columns=["variant"],
        values=["runtime", "memory"],
    )

    if "control" in results["variant"].tolist():
        pivot.columns = pivot.columns.set_names("value", level=1)
        pivot[("runtime", "gsm/control")] = (
            pivot.loc[:, ("runtime", "gsm")] / pivot.loc[:, ("runtime", "control")]
        )
        pivot[("runtime", "dp/control")] = (
            pivot.loc[:, ("runtime", "dp")] / pivot.loc[:, ("runtime", "control")]
        )
        pivot[("memory", "gsm/control")] = (
            pivot.loc[:, ("memory", "gsm")] / pivot.loc[:, ("memory", "control")]
        )
        pivot[("memory", "dp/control")] = (
            pivot.loc[:, ("memory", "dp")] / pivot.loc[:, ("memory", "control")]
        )

    pivot.sort_index(axis=1).sort_values(
        ["batch_size", "num_runs", "num_repeats", "forward_only"]
    ).to_csv(save_path)
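
As a sanity check of the pivot logic above, a toy input (hypothetical numbers for a single `linear` layer and its GSM variant) reproduces the ratio columns:

```python
import pandas as pd

# Two hypothetical rows mirroring the dicts generate_report builds
# from the pickled benchmark outputs.
results = pd.DataFrame(
    [
        {"layer": "linear", "variant": "control", "base_layer": "linear",
         "batch_size": 16, "num_runs": 100, "num_repeats": 20,
         "forward_only": False, "runtime": 0.0010, "memory": 0.0},
        {"layer": "gsm_linear", "variant": "gsm", "base_layer": "linear",
         "batch_size": 16, "num_runs": 100, "num_repeats": 20,
         "forward_only": False, "runtime": 0.0031, "memory": 0.0},
    ]
)

pivot = results.pivot_table(
    index=["batch_size", "num_runs", "num_repeats", "forward_only", "base_layer"],
    columns=["variant"],
    values=["runtime", "memory"],
)
pivot[("runtime", "gsm/control")] = (
    pivot[("runtime", "gsm")] / pivot[("runtime", "control")]
)
print(pivot[("runtime", "gsm/control")])  # 3.1, comparable to the linear row above
```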
