
feat: Implement Dynamic shapes + fallback support for export path #2271


Merged: 87 commits, Oct 2, 2023

Commits
f1f202e
feat: Move tracing to use aot export apis
peri044 Aug 8, 2023
abaf047
chore: minor changes
peri044 Aug 9, 2023
bb1f3cf
chore: minor changes
peri044 Aug 11, 2023
3d05b4d
chore: Rebase with main
peri044 Aug 11, 2023
8d99be5
chore: rebase
peri044 Aug 16, 2023
0aad214
chore: minor logging updates
peri044 Aug 17, 2023
8899735
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
8af2627
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
f6969be
Key op fixes for failing tests
gs-olive Aug 5, 2023
bad1594
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
db56dd6
chore: Move to new export APIs
peri044 Aug 17, 2023
bf961f5
chore: rebase with dynamo_tensor_freeze branch
peri044 Aug 17, 2023
b13aa82
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
dd95620
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
6bd3c64
Key op fixes for failing tests
gs-olive Aug 5, 2023
248073f
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
3e5f434
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
6bf6945
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
3b6e1e7
Key op fixes for failing tests
gs-olive Aug 5, 2023
2107d8e
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
fd5a41e
chore: add BERT test case
peri044 Aug 18, 2023
f047651
chore: remove pdb
peri044 Aug 21, 2023
de9795d
feat: Implement dynamic shapes feature
peri044 Aug 23, 2023
ab76c0d
chore: rebase
peri044 Aug 23, 2023
5f2a4f3
chore: minor update
peri044 Aug 23, 2023
566fbb0
Merge branch 'export_prototype' into dyn_export
peri044 Aug 23, 2023
4949549
chore: refactor
peri044 Aug 23, 2023
e4df382
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
d022f4a
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
9610ba7
Key op fixes for failing tests
gs-olive Aug 5, 2023
e19aae7
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
d95c360
chore: Add constraints for dynamic inputs during export
peri044 Aug 25, 2023
2860be6
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Aug 25, 2023
d7f2477
chore: rebase with export_prototype
peri044 Aug 25, 2023
b50d362
chore: enable truncate long and double inputs
peri044 Aug 25, 2023
91b47fb
chore: updates
peri044 Aug 29, 2023
51266db
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
2005db7
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
a8cb1fe
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
7ff9309
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
692921e
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
e926724
chore: rebase
peri044 Sep 5, 2023
0cfd23b
Merge branch 'export_prototype' into dyn_export
peri044 Sep 5, 2023
2c85bc7
chore: minor changes
peri044 Sep 6, 2023
1de79b3
chore: add device updates
peri044 Sep 6, 2023
33ddf46
chore: minor updates
peri044 Sep 7, 2023
39e7d98
chore: refactor prepare_inputs
peri044 Sep 7, 2023
760eda6
chore: minor updates
peri044 Sep 7, 2023
b3c9666
chore: updates
peri044 Sep 7, 2023
fbfb8ef
chore: updates
peri044 Sep 7, 2023
b7056a1
chore: add tests and update GHA
peri044 Sep 8, 2023
27681c2
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
056cbf3
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
ece276c
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
73a0bce
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
890ba72
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
980dc1c
chore: rebase
peri044 Sep 9, 2023
dfc4899
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
09b099a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 9, 2023
157bb2d
chore: updates
peri044 Sep 9, 2023
0005a31
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
5526bca
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 11, 2023
3420fb0
chore: updates
peri044 Sep 11, 2023
4a0afd3
Merge branch 'export_prototype' into dyn_export
peri044 Sep 11, 2023
399f929
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
4b44ff2
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
a94a075
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
4e308f1
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
95d3f98
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
529262a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 11, 2023
16d19e8
Merge branch 'export_prototype' into dyn_export
peri044 Sep 11, 2023
aee529b
chore: rebase
peri044 Sep 12, 2023
89acc8e
Merge branch 'export_prototype' into dyn_export
peri044 Sep 12, 2023
6858c0a
chore: rebase
peri044 Sep 18, 2023
695dc9b
chore: address review comments
peri044 Sep 18, 2023
a5cdd24
chore: updates
peri044 Sep 18, 2023
24922fc
chore: updates
peri044 Sep 24, 2023
645816e
chore: updates
peri044 Sep 24, 2023
708ac64
chore: rebase with main
peri044 Sep 25, 2023
9bcaf49
chore: update docs
peri044 Sep 26, 2023
6704cb7
chore: update docs
peri044 Sep 26, 2023
560c779
chore: update docs
peri044 Sep 26, 2023
03f5f2d
chore: rebase
peri044 Sep 30, 2023
0349810
chore: fix tests
peri044 Oct 1, 2023
912fcab
chore: updates
peri044 Oct 2, 2023
31c09b2
chore: revert harness tracer changes
peri044 Oct 2, 2023
9f0a589
chore: address review comments
peri044 Oct 2, 2023
1 change: 1 addition & 0 deletions .github/workflows/build-test.yml
@@ -141,6 +141,7 @@ jobs:
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
popd

tests-py-torch-compile-be:
2 changes: 2 additions & 0 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -60,6 +60,8 @@ class DataType {
enum Value : int8_t {
/// INT64
kLong,
/// FP64
kDouble,
/// FP32
kFloat,
/// FP16
8 changes: 7 additions & 1 deletion cpp/src/types.cpp
@@ -97,6 +97,8 @@ at::ScalarType toAtenDataType(DataType value) {
return at::kInt;
case DataType::kLong:
return at::kLong;
case DataType::kDouble:
return at::kDouble;
case DataType::kBool:
return at::kBool;
case DataType::kFloat:
@@ -119,7 +121,8 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {

DataType::DataType(c10::ScalarType t) {
TORCHTRT_CHECK(
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kDouble || t == at::kInt ||
t == at::kBool,
"Data type is unsupported (" << t << ")");
switch (t) {
case at::kHalf:
@@ -134,6 +137,9 @@ DataType::DataType(c10::ScalarType t) {
case at::kLong:
value = DataType::kLong;
break;
case at::kDouble:
value = DataType::kDouble;
break;
case at::kBool:
value = DataType::kBool;
break;
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -42,6 +42,7 @@ User Guide
* :ref:`getting_started_with_fx`
* :ref:`ptq`
* :ref:`runtime`
* :ref:`dynamic_shapes`
* :ref:`use_from_pytorch`
* :ref:`using_dla`

@@ -54,6 +55,7 @@ User Guide
user_guide/getting_started_with_fx_path
user_guide/ptq
user_guide/runtime
user_guide/dynamic_shapes
user_guide/use_from_pytorch
user_guide/using_dla

218 changes: 218 additions & 0 deletions docsrc/user_guide/dynamic_shapes.rst
@@ -0,0 +1,218 @@
.. _dynamic_shapes:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler, which requires some prior information about the input shapes to compile and optimize the model.
In the case of dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. Example usage with static and dynamic shapes is as follows.

NOTE: The following code uses the dynamo IR. In case of the TorchScript IR, please swap out ``ir=dynamo`` with ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
# Compile with static shapes
inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
# or compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                              opt_shape=[4, 3, 224, 224],
                              max_shape=[8, 3, 224, 224],
                              dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

Under the hood
--------------

There are two phases of compilation when we use ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

In the tracing phase, we use torch.export along with the constraints. In the case of
dynamic-shaped inputs, the range can be provided to the tracing via constraints. Please
refer to this `docstring <https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/export/__init__.py#L372-L434>`_
for detailed information on how to set constraints. In short, we create new inputs for
torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the
intermediate output shapes, which can be used when the graph has a mix of PyTorch
and TensorRT submodules. A minimal sketch of invoking both phases explicitly is shown below.
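
The following sketch assumes the dynamic ``inputs`` object defined as in the example above; ``torch_tensorrt.compile`` with ``ir=dynamo`` performs both steps for you, so this is only illustrative.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                               opt_shape=[4, 3, 224, 224],
                               max_shape=[8, 3, 224, 224],
                               dtype=torch.float32)]

# Phase 1: trace via torch.export; constraints are derived from the min/max shapes
graph_module = torch_tensorrt.dynamo.trace(model, inputs)
# Phase 2: convert the resulting torch.fx.GraphModule to TensorRT
trt_gm = torch_tensorrt.dynamo.compile(graph_module, inputs=inputs)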

Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows

.. code-block:: python

for dim in constraint_dims:
if min_shape[dim] > 1:
constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
if max_shape[dim] > 1:
constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
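
For instance, with the dynamic input from the first example (``min_shape=[1, 3, 224, 224]``, ``max_shape=[8, 3, 224, 224]``) and only the batch dimension marked dynamic, this loop would yield a single constraint (illustrative sketch; ``trace_input`` is the internal tracing input):

.. code-block:: python

# min_shape[0] == 1, so no lower-bound constraint is appended;
# max_shape[0] == 8 > 1, so only the upper bound is appended
constraints = [dynamic_dim(trace_input, 0) <= 8]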

Sometimes we might need to set additional constraints, and Torchdynamo errors out if we don't specify them.
For example, in the case of BERT model compilation, there are two inputs, and a constraint has to be set involving the sequence lengths of these two inputs.

.. code-block:: python

constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))


If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

import unittest.mock

import torch
import torch_tensorrt
from torch.export import dynamic_dim, export  # torch.export APIs (torch >= 2.1)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

# Assume the model has two inputs
model = MyModel()
torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                       opt_shape=[4, 14],
                                       max_shape=[8, 14],
                                       dtype=torch.int32),
                  torch_tensorrt.Input(min_shape=[1, 14],
                                       opt_shape=[4, 14],
                                       max_shape=[8, 14],
                                       dtype=torch.int32)]

# Export the model with additional constraints
constraints = []
# The following constraints are automatically added by Torch-TensorRT in the
# general case when you call torch_tensorrt.compile directly on MyModel()
constraints.append(dynamic_dim(torch_input_1, 0) < 8)
constraints.append(dynamic_dim(torch_input_2, 0) < 8)
# This is an additional constraint as instructed by Torchdynamo
constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
experimental_decompositions = False  # whether to enable experimental decompositions
with unittest.mock.patch(
    "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
    graph_module = export(
        model, (torch_input_1, torch_input_2), constraints=constraints
    ).module()

# Use the dynamo.compile API
compile_spec = {}  # any additional torch_tensorrt.dynamo.compile settings
trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)

Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators, and
the compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above:

.. code-block:: python

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.avgpool(x)
        out = torch.flatten(x, 1)
        return out

model = MyModule().eval().cuda()
# Compile with dynamic shapes
inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                              opt_shape=(4, 512, 1, 1),
                              max_shape=(8, 512, 1, 1),
                              dtype=torch.float32)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)


The traced graph of ``MyModule()`` looks as follows:

.. code-block:: python

Post export graph: graph():
    %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
    %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
    return (view,)


Here the ``%sym_size`` node captures the dynamic batch dimension and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be part of our next release.

Workaround (BERT static compilation example)
--------------------------------------------

If you encounter the issues mentioned in the **Limitations** section,
you can compile the model in static mode with the maximum input size that can be provided. For smaller inputs,
you can pad them accordingly, as sketched after the example below. This is only a workaround until we address the limitations.

.. code-block:: python

import torch
import torch_tensorrt
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
compile_spec = {}  # any additional torch_tensorrt.compile settings

# Input sequence length is 20.
input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
model_outputs = trt_mod(input1, input2)

# If you have a sequence of length 14, pad 6 zero tokens and run inference,
# or recompile for a sequence length of 14.
input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
model_outputs = trt_mod(input1, input2)

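A minimal padding sketch for the smaller-input case (assumptions: right padding, pad token id 0, a compiled maximum sequence length of 20, and a hypothetical ``pad_to_max`` helper that is not part of the Torch-TensorRT API):

.. code-block:: python

import torch
import torch.nn.functional as F

def pad_to_max(input_ids, attention_mask, max_len=20):
    # Right-pad both tensors along the sequence dimension up to the compiled length
    pad = max_len - input_ids.shape[1]
    input_ids = F.pad(input_ids, (0, pad), value=0)            # 0 assumed to be the pad token id
    attention_mask = F.pad(attention_mask, (0, pad), value=0)  # 0 marks padded (ignored) positions
    return input_ids, attention_mask

short_ids = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
short_mask = torch.ones((1, 14), dtype=torch.int32).to("cuda")
padded_ids, padded_mask = pad_to_max(short_ids, short_mask)
# padded_ids / padded_mask now have shape (1, 20) and can be passed to the
# statically compiled trt_mod from the example above.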

Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes on the first execution of the model.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
# Compilation happens when you call the model
trt_gm(inputs)

# Recompilation happens with a modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)


15 changes: 14 additions & 1 deletion py/torch_tensorrt/_Input.py
@@ -46,6 +46,7 @@ class _ShapeMode(Enum):
low_tensor_domain_incl: float = 0.0
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_dtype: torch.dtype = torch.float32
torch_tensor: torch.Tensor = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
@@ -171,6 +172,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

self.tensor_domain = Input._parse_tensor_domain(domain)

if "torch_tensor" in kwargs:
self.torch_tensor = kwargs["torch_tensor"]
else:
if self.shape_mode == Input._ShapeMode.DYNAMIC:
self.torch_tensor = self.example_tensor("opt_shape")
else:
self.torch_tensor = self.example_tensor()
Comment on lines +175 to +181

Collaborator: What if "torch_tensor" is provided but the shape is also dynamic - would it be an issue if this tensor had the shape of min or max instead of opt? This could be refactored to override or validate the specified tensor when the shape is dynamic.

Collaborator Author (@peri044, Sep 18, 2023):
> would it be an issue if this Tensor had the shape of min or max instead of opt?
The use case you are describing is when the user provides
x = torch_tensorrt.Input(min_shape=<>, opt_shape=<>, max_shape=<>, torch_tensor=<random_tensor>)
This wouldn't be an issue, since we ignore torch_tensor in this case of dynamic-shape compilation.
torch_inputs = get_torch_inputs(inputs, device, mode)
would generate input tensors for all the modes. We could perhaps pass a warning to users that torch_tensor is not being used in such cases.

Collaborator: I see - I believe the feature ask from #2323 is to allow providing tensors for each of min, opt, and max. Could torch_tensor instead be either a list/tuple of 3 tensors (min, opt, max) or a single tensor?


def __str__(self) -> str:
if self.shape_mode == Input._ShapeMode.STATIC:
return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
@@ -220,6 +229,8 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
return _enums.dtype.half
elif dtype == torch.float:
return _enums.dtype.float
elif dtype == torch.float64:
return _enums.dtype.double
elif dtype == torch.bool:
return _enums.dtype.bool
else:
@@ -249,6 +260,8 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
return torch.float
elif dtype == _enums.dtype.bool:
return torch.bool
elif dtype == _enums.dtype.double:
return torch.float64
else:
# Default torch_dtype used in FX path
return torch.float32
@@ -354,7 +367,7 @@ def from_tensor(
)
else torch.channels_last
)
return cls(shape=t.shape, dtype=t.dtype, format=frmt)
return cls(shape=t.shape, dtype=t.dtype, format=frmt, torch_tensor=t)

@classmethod
def from_tensors(
17 changes: 9 additions & 8 deletions py/torch_tensorrt/_compile.py
@@ -214,19 +214,20 @@ def compile(
)
return compiled_fx_module
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
import collections.abc

from torch_tensorrt import Device
from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
device = kwargs.get("device", Device._current_device())
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
if not isinstance(input_list, collections.abc.Sequence):
input_list = [input_list]

# Export the module
torchtrt_inputs = prepare_inputs(input_list)
module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
module,
inputs=input_list,
inputs=torchtrt_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
6 changes: 6 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -18,6 +18,8 @@ std::string to_str(DataType value) {
return "Float";
case DataType::kLong:
return "Long";
case DataType::kDouble:
return "Double";
default:
return "Unknown data type";
}
@@ -33,6 +35,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
return nvinfer1::DataType::kINT32;
case DataType::kLong:
return nvinfer1::DataType::kINT32;
case DataType::kDouble:
return nvinfer1::DataType::kFLOAT;
case DataType::kBool:
return nvinfer1::DataType::kBOOL;
case DataType::kFloat:
@@ -58,6 +62,8 @@ at::ScalarType toAtenDataType(DataType value) {
return at::kBool;
case DataType::kFloat:
return at::kFloat;
case DataType::kDouble:
return at::kDouble;
case DataType::kUnknown:
return at::kFloat;
default:
2 changes: 1 addition & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -27,7 +27,7 @@ namespace pyapi {
return static_cast<int64_t>(field_name); \
}

enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
enum class DataType : int8_t { kLong, kDouble, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
std::string to_str(DataType value);
nvinfer1::DataType toTRTDataType(DataType value);
at::ScalarType toAtenDataType(DataType value);