diff --git a/torch/ao/quantization/fx/backend_config/tensorrt.py b/torch/ao/quantization/fx/backend_config/tensorrt.py index 6504ce7a93319..6237a6491e116 100644 --- a/torch/ao/quantization/fx/backend_config/tensorrt.py +++ b/torch/ao/quantization/fx/backend_config/tensorrt.py @@ -4,7 +4,7 @@ import torch.nn.intrinsic as nni import torch.nn.intrinsic.qat as nniqat -from ...fuser_method_mappings import reverse2 +from ...fuser_method_mappings import reverse_sequential_wrapper2 def get_tensorrt_backend_config_dict(): """ Get the backend config dictionary for tensorrt backend @@ -63,7 +63,7 @@ def get_tensorrt_backend_config_dict(): "dtype_configs": [ weighted_op_qint8_dtype_config, ], - "fuser_method": reverse2(nni.LinearReLU), + "fuser_method": reverse_sequential_wrapper2(nni.LinearReLU), } linear_relu_mf_config = { "pattern": (torch.nn.functional.relu, torch.nn.Linear), @@ -71,7 +71,7 @@ def get_tensorrt_backend_config_dict(): "dtype_configs": [ weighted_op_qint8_dtype_config, ], - "fuser_method": reverse2(nni.LinearReLU), + "fuser_method": reverse_sequential_wrapper2(nni.LinearReLU), } linear_relu_fused_config = { @@ -156,7 +156,7 @@ def get_tensorrt_backend_config_dict(): "dtype_configs": [ weighted_op_qint8_dtype_config, ], - "fuser_method": reverse2(nni.ConvReLU2d), + "fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d), } conv2d_relu_mm_config = { "pattern": (torch.nn.ReLU, torch.nn.Conv2d), @@ -164,7 +164,7 @@ def get_tensorrt_backend_config_dict(): "dtype_configs": [ weighted_op_qint8_dtype_config, ], - "fuser_method": reverse2(nni.ConvReLU2d), + "fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d), } addmm_config = { "pattern": torch.addmm,