diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a9dd0d9e9b..00d2eb61af 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs): self._build_and_save(model, build_args) # This will reassign input_names from the engine self._load_engine() + assert self.engine is not None except Exception as e: if self.fallback: self.logger.info(f"Failed to build engine: {e}") @@ -403,8 +404,10 @@ def _onnx_to_trt(self, onnx_path): build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" - build_args["fp16"] = self.precision == "fp16" - build_args["bf16"] = self.precision == "bf16" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) @@ -502,6 +505,7 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some networks may not work at all with TRT 8.x. Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. 
diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 21125d203f..2f9db8f0c2 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -20,10 +20,10 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 -from monai.utils import optional_import +from monai.utils import min_version, optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt, trt_imported = optional_import("tensorrt") +trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")