feat: Add support for device compilation setting #2190

Merged

8 changes: 5 additions & 3 deletions py/torch_tensorrt/_Device.py
@@ -8,11 +8,10 @@
 import warnings
 
-import torch
-from torch_tensorrt import logging
-
 # from torch_tensorrt import _enums
 import tensorrt as trt
+import torch
+from torch_tensorrt import logging
 
 try:
     from torch_tensorrt import _C
@@ -120,6 +119,9 @@ def __str__(self) -> str:
             )
         )
 
+    def __repr__(self) -> str:
+        return self.__str__()
+
     def _to_internal(self) -> _C.Device:
         internal_dev = _C.Device()
         if self.device_type == trt.DeviceType.GPU:
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py
@@ -209,12 +209,12 @@ def compile(
         import collections.abc
 
         from torch_tensorrt import Device
-        from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+        from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
 
         if not isinstance(inputs, collections.abc.Sequence):
             inputs = [inputs]
         device = kwargs.get("device", Device._current_device())
-        torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+        torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
         module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
         compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
             module,
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -1,7 +1,9 @@
 import torch
+from torch_tensorrt._Device import Device
 
 PRECISION = torch.float32
 DEBUG = False
+DEVICE = None
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
@@ -12,3 +14,7 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
+
+
+def default_device() -> Device:
+    return Device(gpu_id=torch.cuda.current_device())
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -2,6 +2,7 @@
 from typing import Optional, Set
 
 import torch
+from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -15,6 +16,7 @@
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
+    default_device,
 )
@@ -54,3 +56,4 @@ class CompilationSettings:
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
+    device: Device = field(default_factory=default_device)

Collaborator (Author): Device is now populated with a default factory once the CompilationSettings object is instantiated.

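As a minimal sketch of that behavior (with a hypothetical ExampleSettings standing in for CompilationSettings), field(default_factory=...) defers the lookup until the dataclass is instantiated, so the default tracks whichever CUDA device is current at that moment:

from dataclasses import dataclass, field

import torch
from torch_tensorrt._Device import Device


def default_device() -> Device:
    # Evaluated when the dataclass is instantiated, not when the class is defined
    return Device(gpu_id=torch.cuda.current_device())


@dataclass
class ExampleSettings:
    device: Device = field(default_factory=default_device)


settings = ExampleSettings()  # settings.device reflects torch.cuda.current_device()
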
16 changes: 12 additions & 4 deletions py/torch_tensorrt/dynamo/compile.py
@@ -2,7 +2,7 @@
 
 import collections.abc
 import logging
-from typing import Any, List, Optional, Sequence, Set, Tuple
+from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 import torch_tensorrt
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo import CompilationSettings, partitioning
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    DEVICE,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
@@ -29,7 +30,11 @@
     convert_module,
     repair_long_or_double_inputs,
 )
-from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+from torch_tensorrt.dynamo.utils import (
+    prepare_inputs,
+    to_torch_device,
+    to_torch_tensorrt_device,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -38,7 +43,7 @@ def compile(
     gm: Any,
     inputs: Any,
     *,
-    device: Device = Device._current_device(),
+    device: Optional[Union[Device, torch.device, str]] = DEVICE,

Collaborator (Author): Device can now be None, in which case the default Torch device is used.

     disable_tf32: bool = False,
     sparse_weights: bool = False,
     enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
@@ -82,7 +87,9 @@
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    _, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+    device = to_torch_tensorrt_device(device)
+
+    _, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
 
     if (
         torch.float16 in enabled_precisions
@@ -105,6 +112,7 @@
     compilation_options = {
         "precision": precision,
         "debug": debug,
+        "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops
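As a hedged illustration of the widened signature (not code from this PR; assumes a CUDA-capable machine and a gm/inputs pair already produced via torch_tensorrt.dynamo.trace), all of the following now resolve to a valid target device:

import torch
import torch_tensorrt

# Strings, torch.device, and torch_tensorrt.Device are all accepted;
# omitting the argument leaves device=None, which falls back to the
# current Torch CUDA device inside the settings parser.
trt_gm = torch_tensorrt.dynamo.compile(gm, inputs, device="cuda:0")
trt_gm = torch_tensorrt.dynamo.compile(gm, inputs, device=torch.device("cuda:0"))
trt_gm = torch_tensorrt.dynamo.compile(gm, inputs, device=torch_tensorrt.Device(gpu_id=0))
trt_gm = torch_tensorrt.dynamo.compile(gm, inputs)  # device=None
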
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/conversion.py
@@ -3,14 +3,13 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -72,4 +71,5 @@ def convert_module(
         name=name,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
+        target_device=settings.device,
     )
52 changes: 42 additions & 10 deletions py/torch_tensorrt/dynamo/utils.py
@@ -2,7 +2,7 @@
 
 import logging
 from dataclasses import fields, replace
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
 import torch_tensorrt
@@ -116,23 +116,45 @@ def prepare_inputs(
     )
 
 
-def prepare_device(device: Device | torch.device) -> torch.device:
-    _device: torch.device
+def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device:
+    """Cast a device-type to torch.device
+
+    Returns the corresponding torch.device
+    """
     if isinstance(device, Device):
         if device.gpu_id != -1:
-            _device = torch.device(device.gpu_id)
+            return torch.device(device.gpu_id)
         else:
             raise ValueError("Invalid GPU ID provided for the CUDA device provided")
 
     elif isinstance(device, torch.device):
-        _device = device
+        return device
+
+    elif device is None:
+        return torch.device(torch.cuda.current_device())
+
     else:
-        raise ValueError(
-            "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
-        )
-
-    return _device
+        return torch.device(device)

+def to_torch_tensorrt_device(
+    device: Optional[Union[Device, torch.device, str]]
+) -> Device:
+    """Cast a device-type to torch_tensorrt.Device
+
+    Returns the corresponding torch_tensorrt.Device
+    """
+    if isinstance(device, Device):
+        return device
+
+    elif isinstance(device, torch.device):
+        return Device(gpu_id=device.index)
+
+    elif device is None:
+        return Device(gpu_id=torch.cuda.current_device())
+
+    else:
+        return Device(device)

Collaborator (Author), on lines +153 to +154 (the device-is-None branch): If the device is None, we use the default Torch context device.


def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
@@ -184,7 +206,17 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
 
-    logger.info("Compilation Settings: %s\n", settings)
+    # Ensure device is a torch_tensorrt Device
+    settings.device = to_torch_tensorrt_device(settings.device)
+
+    # Check and update device settings
+    if "device" not in kwargs:
+        logger.info(
+            f"Device not specified, using Torch default current device - cuda:{settings.device.gpu_id}. "
+            "If this is incorrect, please specify an input device, via the device keyword."
+        )
+
+    logger.info(f"Compiling with Settings:\n{settings}")
 
     return settings

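A short usage sketch of the two casting helpers added above (hedged; assumes at least one visible CUDA device):

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.utils import to_torch_device, to_torch_tensorrt_device

# Every accepted input normalizes to a concrete device on both sides
to_torch_device("cuda:0")                         # torch.device for GPU 0
to_torch_device(torch_tensorrt.Device(gpu_id=0))  # torch.device for GPU 0
to_torch_device(None)                             # torch.device for torch.cuda.current_device()

to_torch_tensorrt_device(torch.device("cuda:0"))  # torch_tensorrt.Device with gpu_id=0
to_torch_tensorrt_device(None)                    # torch_tensorrt.Device for the current CUDA device
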
51 changes: 43 additions & 8 deletions tests/py/dynamo/backend/test_compiler_utils.py
@@ -1,26 +1,61 @@
-from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
-from utils import same_output_format
-import torch_tensorrt
 import unittest
 
 import torch
+import torch_tensorrt
+from torch_tensorrt.dynamo.utils import (
+    prepare_inputs,
+    to_torch_device,
+    to_torch_tensorrt_device,
+)
+from utils import same_output_format
 
 
-class TestPrepareDevice(unittest.TestCase):
-    def test_prepare_cuda_device(self):
+class TestToTorchDevice(unittest.TestCase):
+    def test_cast_cuda_device(self):
         gpu_id = 0
         device = torch.device(f"cuda:{gpu_id}")
-        prepared_device = prepare_device(device)
+        prepared_device = to_torch_device(device)
         self.assertTrue(isinstance(prepared_device, torch.device))
         self.assertTrue(prepared_device.index == gpu_id)
 
-    def test_prepare_trt_device(self):
+    def test_cast_trt_device(self):
         gpu_id = 4
         device = torch_tensorrt.Device(gpu_id=gpu_id)
-        prepared_device = prepare_device(device)
+        prepared_device = to_torch_device(device)
         self.assertTrue(isinstance(prepared_device, torch.device))
         self.assertTrue(prepared_device.index == gpu_id)
 
+    def test_cast_str_device(self):
+        gpu_id = 2
+        device = f"cuda:{2}"
+        prepared_device = to_torch_device(device)
+        self.assertTrue(isinstance(prepared_device, torch.device))
+        self.assertTrue(prepared_device.index == gpu_id)
+
+
+class TestToTorchTRTDevice(unittest.TestCase):
+    def test_cast_cuda_device(self):
+        gpu_id = 0
+        device = torch.device(f"cuda:{gpu_id}")
+        prepared_device = to_torch_tensorrt_device(device)
+        self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
+        self.assertTrue(prepared_device.gpu_id == gpu_id)
+
+    def test_cast_trt_device(self):
+        gpu_id = 4
+        device = torch_tensorrt.Device(gpu_id=gpu_id)
+        prepared_device = to_torch_tensorrt_device(device)
+        self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
+        self.assertTrue(prepared_device.gpu_id == gpu_id)
+
+    def test_cast_str_device(self):
+        gpu_id = 2
+        device = f"cuda:{2}"
+        prepared_device = to_torch_tensorrt_device(device)
+        self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
+        self.assertTrue(prepared_device.gpu_id == gpu_id)
+
 
 class TestPrepareInputs(unittest.TestCase):
     def test_prepare_single_tensor_input(self):
         inputs = [torch.ones((4, 4))]