
Commit 2b130ec

feat: Add functionality for easily benchmarking fx
- Add fx path in benchmarking code
- Add fx saving tools to `utils` and `hub`
- Add PyTorch model parsing and loading in `perf_run` script
1 parent 2b1cedf commit 2b130ec

File tree: 4 files changed, +104 -22 lines changed

- tools/perf/benchmark.sh
- tools/perf/hub.py
- tools/perf/perf_run.py
- tools/perf/utils.py


tools/perf/benchmark.sh

+6-3
@@ -12,9 +12,10 @@ echo "Benchmarking VGG16 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
-    --backends torch,torch_tensorrt,tensorrt \
+    --backends torch,torch_tensorrt,tensorrt,fx2trt \
     --report "vgg_perf_bs${bs}.txt"
 done

@@ -23,9 +24,10 @@ echo "Benchmarking Resnet50 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
-    --backends torch,torch_tensorrt,tensorrt \
+    --backends torch,torch_tensorrt,tensorrt,fx2trt \
     --report "rn50_perf_bs${bs}.txt"
 done

@@ -45,9 +47,10 @@ echo "Benchmarking EfficientNet-B0 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
+    --model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \
     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
     --batch_size ${bs} \
-    --backends torch,torch_tensorrt,tensorrt \
+    --backends torch,torch_tensorrt,tensorrt,fx2trt \
     --report "eff_b0_perf_bs${bs}.txt"
 done
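The `--model_torch` arguments above point at the `*_pytorch.pt` files that the updated `hub.py` writes next to the TorchScript artifacts, so `hub.py` must be run before this script. As a minimal sketch of that convention (not part of the commit; the torchvision model below is just an example), each file is a whole serialized `nn.Module`:

import torch
import torchvision.models as models

# Save the entire nn.Module (not a state_dict); perf_run.py reloads it
# directly with torch.load(...) when the fx2trt backend is requested.
model = models.vgg16(weights=None).eval()
torch.save(model, "models/vgg16_pytorch.pt")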

tools/perf/hub.py

+63-12
@@ -21,12 +21,19 @@
 # Downloads all model files again if manifest file is not present
 MANIFEST_FILE = "model_manifest.json"
 
+# Valid paths for model-saving specification
+VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")
+
+# Key models selected for benchmarking with their respective paths
 BENCHMARK_MODELS = {
-    "vgg16": {"model": models.vgg16(weights=None), "path": "script"},
-    "resnet50": {"model": models.resnet50(weights=None), "path": "script"},
+    "vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
+    "resnet50": {
+        "model": models.resnet50(weights=None),
+        "path": ["script", "pytorch"],
+    },
     "efficientnet_b0": {
         "model": timm.create_model("efficientnet_b0", pretrained=True),
-        "path": "script",
+        "path": ["script", "pytorch"],
     },
     "vit": {
         "model": timm.create_model("vit_base_patch16_224", pretrained=True),

@@ -40,18 +47,26 @@ def get(n, m, manifest):
     print("Downloading {}".format(n))
     traced_filename = "models/" + n + "_traced.jit.pt"
     script_filename = "models/" + n + "_scripted.jit.pt"
+    pytorch_filename = "models/" + n + "_pytorch.pt"
     x = torch.ones((1, 3, 300, 300)).cuda()
-    if n == "bert-base-uncased":
+    if n == "bert_base_uncased":
         traced_model = m["model"]
         torch.jit.save(traced_model, traced_filename)
         manifest.update({n: [traced_filename]})
     else:
         m["model"] = m["model"].eval().cuda()
-        if m["path"] == "both" or m["path"] == "trace":
+
+        # Get all desired model save specifications as list
+        paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
+
+        # Depending on specified model save specifications, save desired model formats
+        if any(path in ("all", "torchscript", "trace") for path in paths):
+            # (TorchScript) Traced model
             trace_model = torch.jit.trace(m["model"], [x])
             torch.jit.save(trace_model, traced_filename)
             manifest.update({n: [traced_filename]})
-        if m["path"] == "both" or m["path"] == "script":
+        if any(path in ("all", "torchscript", "script") for path in paths):
+            # (TorchScript) Scripted model
             script_model = torch.jit.script(m["model"])
             torch.jit.save(script_model, script_filename)
             if n in manifest.keys():
@@ -60,6 +75,15 @@ def get(n, m, manifest):
                 manifest.update({n: files})
             else:
                 manifest.update({n: [script_filename]})
+        if any(path in ("all", "pytorch") for path in paths):
+            # (PyTorch Module) model
+            torch.save(m["model"], pytorch_filename)
+            if n in manifest.keys():
+                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+                files.append(pytorch_filename)
+                manifest.update({n: files})
+            else:
+                manifest.update({n: [pytorch_filename]})
     return manifest
 
 
@@ -72,15 +96,35 @@ def download_models(version_matches, manifest):
     for n, m in BENCHMARK_MODELS.items():
         scripted_filename = "models/" + n + "_scripted.jit.pt"
         traced_filename = "models/" + n + "_traced.jit.pt"
+        pytorch_filename = "models/" + n + "_pytorch.pt"
         # Check if model file exists on disk
+
+        # Extract model specifications as list and ensure all desired formats exist
+        paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
         if (
             (
-                m["path"] == "both"
+                any(path == "all" for path in paths)
+                and os.path.exists(scripted_filename)
+                and os.path.exists(traced_filename)
+                and os.path.exists(pytorch_filename)
+            )
+            or (
+                any(path == "torchscript" for path in paths)
                 and os.path.exists(scripted_filename)
                 and os.path.exists(traced_filename)
             )
-            or (m["path"] == "script" and os.path.exists(scripted_filename))
-            or (m["path"] == "trace" and os.path.exists(traced_filename))
+            or (
+                any(path == "script" for path in paths)
+                and os.path.exists(scripted_filename)
+            )
+            or (
+                any(path == "trace" for path in paths)
+                and os.path.exists(traced_filename)
+            )
+            or (
+                any(path == "pytorch" for path in paths)
+                and os.path.exists(pytorch_filename)
+            )
         ):
             print("Skipping {} ".format(n))
             continue

@@ -90,7 +134,6 @@ def download_models(version_matches, manifest):
 def main():
     manifest = None
     version_matches = False
-    manifest_exists = False
 
     # Check if Manifest file exists or is empty
     if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:

@@ -99,7 +142,6 @@ def main():
         # Creating an empty manifest file for overwriting post setup
         os.system("touch {}".format(MANIFEST_FILE))
     else:
-        manifest_exists = True
 
         # Load manifest if already exists
         with open(MANIFEST_FILE, "r") as f:

@@ -129,4 +171,13 @@ def main():
         f.truncate()
 
 
-main()
+if __name__ == "__main__":
+    # Ensure all specified desired model formats exist and are valid
+    paths = [
+        [m["path"]] if isinstance(m["path"], str) else m["path"]
+        for m in BENCHMARK_MODELS.values()
+    ]
+    assert all(
+        (path in VALID_PATHS) for path_list in paths for path in path_list
+    ), "Not all 'path' attributes in BENCHMARK_MODELS are valid"
+    main()
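The list-valued `path` field, together with `VALID_PATHS` and the normalization `paths = [m["path"]] if isinstance(m["path"], str) else m["path"]`, defines which artifacts `get()` saves per model. As an illustration only (the `mobilenet_v2` entry below is hypothetical and not part of the commit), a model requesting both a scripted TorchScript file and a serialized PyTorch module would be registered like this:

# Hypothetical entry; get() would then write both
# models/mobilenet_v2_scripted.jit.pt and models/mobilenet_v2_pytorch.pt.
BENCHMARK_MODELS["mobilenet_v2"] = {
    "model": models.mobilenet_v2(weights=None),
    "path": ["script", "pytorch"],  # a single string such as "script" is also accepted
}

The new `__main__` guard then validates every such entry against `VALID_PATHS` before `main()` runs.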

tools/perf/perf_run.py

+31-3
@@ -4,6 +4,7 @@
 
 import time
 import timeit
+import warnings
 import numpy as np
 import torch.backends.cudnn as cudnn
 

@@ -147,6 +148,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
        max_batch_size=batch_size,
        lower_precision=precision,
        verbose_log=False,
+        explicit_batch_dimension=True,
    )
    end_compile = time.time_ns()
    compile_time_ms = (end_compile - start_compile) / 1e6

@@ -272,6 +274,7 @@ def run(
     truncate_long_and_double=False,
     batch_size=1,
     is_trt_engine=False,
+    model_torch=None,
 ):
     for backend in backends:
         if precision == "int8":

@@ -323,7 +326,13 @@ def run(
             )
 
         elif backend == "fx2trt":
-            run_fx2trt(model, input_tensors, params, precision, batch_size)
+            if model_torch is None:
+                warnings.warn(
+                    "Requested backend fx2trt without specifying a PyTorch Model, "
+                    + "skipping this backend"
+                )
+                continue
+            run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
 
         elif backend == "tensorrt":
             run_tensorrt(

@@ -399,7 +408,13 @@ def load_model(params):
         type=str,
         help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
     )
-    arg_parser.add_argument("--model", type=str, help="Name of the model file")
+    arg_parser.add_argument("--model", type=str, help="Name of torchscript model file")
+    arg_parser.add_argument(
+        "--model_torch",
+        type=str,
+        default="",
+        help="Name of torch model file (used for fx2trt)",
+    )
     arg_parser.add_argument(
         "--inputs",
         type=str,

@@ -491,16 +506,28 @@ def load_model(params):
     else:
         params = vars(args)
         model_name = params["model"]
+        model = None
+
+        model_name_torch = params["model_torch"]
+        model_torch = None
+
+        # Load TorchScript model
         if os.path.exists(model_name):
-            print("Loading user provided model: ", model_name)
+            print("Loading user provided torchscript model: ", model_name)
             model = torch.jit.load(model_name).cuda().eval()
         elif model_name in BENCHMARK_MODELS:
+            print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
             model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
         else:
             raise ValueError(
                 "Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)"
             )
 
+        # Load PyTorch Model, if provided
+        if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
+            print("Loading user provided torch model: ", model_name_torch)
+            model_torch = torch.load(model_name_torch).eval().cuda()
+
         backends = parse_backends(params["backends"])
         truncate_long_and_double = params["truncate"]
         batch_size = params["batch_size"]

@@ -523,6 +550,7 @@ def load_model(params):
             truncate_long_and_double,
             batch_size,
             is_trt_engine,
+            model_torch=model_torch,
         )
 
     # Generate report
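Note the loading contract for the new flag: `--model_torch` is loaded with `torch.load(model_name_torch).eval().cuda()`, so the file must unpickle to a complete `nn.Module`; a bare `state_dict` would come back as a plain dict and fail. A minimal sketch of that check (the file path is illustrative):

import torch

# Mirror of perf_run.py's loading step for --model_torch.
model_torch = torch.load("models/vgg16_pytorch.pt")
assert isinstance(model_torch, torch.nn.Module), "--model_torch expects a serialized nn.Module"
model_torch = model_torch.eval().cuda()

When `--model_torch` is omitted, the fx2trt backend is skipped with a warning rather than raising, as shown in the hunk above.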

tools/perf/utils.py

+4-4
@@ -5,14 +5,14 @@
 import timm
 
 BENCHMARK_MODELS = {
-    "vgg16": {"model": models.vgg16(pretrained=True), "path": "script"},
+    "vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
     "resnet50": {
-        "model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True),
-        "path": "script",
+        "model": models.resnet50(weights=None),
+        "path": ["script", "pytorch"],
     },
     "efficientnet_b0": {
         "model": timm.create_model("efficientnet_b0", pretrained=True),
-        "path": "script",
+        "path": ["script", "pytorch"],
     },
     "vit": {
         "model": timm.create_model("vit_base_patch16_224", pretrained=True),
