diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 4843ec0145..5d6d27e4ad 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -307,7 +307,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: def TensorRTCompileSpec( inputs: Optional[List[torch.Tensor | Input]] = None, input_signature: Optional[Any] = None, - device: torch.device | Device = Device._current_device(), + device: Optional[torch.device | Device] = None, disable_tf32: bool = False, sparse_weights: bool = False, enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, @@ -365,7 +365,7 @@ def TensorRTCompileSpec( compile_spec = { "inputs": inputs if inputs is not None else [], # "input_signature": input_signature, - "device": device, + "device": Device._current_device() if device is None else device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers. "enabled_precisions": (