
Commit 360f6c4

Add support for FX in all backends path
- Update documentation
- Add new backend for only-TorchScript benchmarks
1 parent: f1bf283

2 files changed (+30, −8)

tools/perf/README.md (+2, −2)

```diff
@@ -66,7 +66,7 @@ There are two sample configuration files added.
 
 | Name | Supported Values | Description |
 | ----------------- | ------------------------------------ | ------------------------------------------------------------ |
-| backend | all, torch, torch_tensorrt, tensorrt, fx2trt | Supported backends for inference. |
+| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" implies the last four methods in the list at left, and "torchscript" implies the last three (excludes fx path) |
 | input | - | Input binding names. Expected to list shapes of each input bindings |
 | model | - | Configure the model filename and name |
 | model_torch | - | Name of torch model file and name (used for fx2trt) (optional) |
@@ -113,7 +113,7 @@ Note:
 
 Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module
 
-* `--backends` : Comma separated string of backends. Eg: torch, torch_tensorrt, tensorrt or fx2trt
+* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
 * `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
 * `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend)
 * `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
```
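For illustration, here is a minimal sketch (not part of the commit) of how the comma-separated `--backends` value described in the README excerpt above could be split and validated; the helper name and validation logic are hypothetical, only the backend strings come from the documented table.

```python
# Hypothetical helper, for illustration only: splits a --backends value such as
# "torch,torch_tensorrt,tensorrt,fx2trt" and checks it against the names the
# README lists as supported. perf_run.py's actual parsing may differ.
SUPPORTED_BACKENDS = {"all", "torchscript", "fx2trt", "torch", "torch_tensorrt", "tensorrt"}


def parse_backends(backends_arg: str) -> list:
    backends = [b.strip() for b in backends_arg.split(",") if b.strip()]
    unknown = [b for b in backends if b not in SUPPORTED_BACKENDS]
    if unknown:
        raise ValueError(f"Unsupported backend(s): {unknown}")
    return backends


print(parse_backends("torch,torch_tensorrt,tensorrt,fx2trt"))
# ['torch', 'torch_tensorrt', 'tensorrt', 'fx2trt']
```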

tools/perf/perf_run.py (+28, −6)

```diff
@@ -299,6 +299,13 @@ def run(
             )
             continue
 
+        if (model_torch is None) and (backend in ("all", "fx2trt")):
+            warnings.warn(
+                f"Requested backend {backend} without specifying a PyTorch Model, "
+                + "skipping this backend"
+            )
+            continue
+
         if backend == "all":
             run_torch(model, input_tensors, params, precision, batch_size)
             run_torch_tensorrt(
@@ -318,6 +325,27 @@ def run(
                 is_trt_engine,
                 batch_size,
             )
+            run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
+
+        elif backend == "torchscript":
+            run_torch(model, input_tensors, params, precision, batch_size)
+            run_torch_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                batch_size,
+            )
+            run_tensorrt(
+                model,
+                input_tensors,
+                params,
+                precision,
+                truncate_long_and_double,
+                is_trt_engine,
+                batch_size,
+            )
 
         elif backend == "torch":
             run_torch(model, input_tensors, params, precision, batch_size)
@@ -333,12 +361,6 @@ def run(
             )
 
         elif backend == "fx2trt":
-            if model_torch is None:
-                warnings.warn(
-                    "Requested backend fx2trt without specifying a PyTorch Model, "
-                    + "skipping this backend"
-                )
-                continue
             run_fx2trt(model_torch, input_tensors, params, precision, batch_size)
 
         elif backend == "tensorrt":
```
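To see the hoisted guard in isolation: a self-contained sketch of the behaviour the moved check gives both the `all` and `fx2trt` paths. The `dispatch` function and the stub `run_fx2trt` below are hypothetical stand-ins, not the real benchmark helpers from perf_run.py.

```python
import warnings


# Hypothetical stand-in for perf_run.py's FX benchmark helper, used only to
# exercise the guard below.
def run_fx2trt(model_torch):
    print(f"benchmarking fx2trt on {model_torch}")


def dispatch(backends, model_torch=None):
    for backend in backends:
        # Mirrors the check this commit hoists above the dispatch: any backend
        # that exercises the FX path needs a PyTorch (non-TorchScript) model.
        if model_torch is None and backend in ("all", "fx2trt"):
            warnings.warn(
                f"Requested backend {backend} without specifying a PyTorch Model, "
                "skipping this backend"
            )
            continue
        if backend in ("all", "fx2trt"):
            run_fx2trt(model_torch)


dispatch(["fx2trt"])                        # warns and skips the FX path
dispatch(["fx2trt"], model_torch="vgg16")   # runs the FX path
```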
