-
Notifications
You must be signed in to change notification settings - Fork 364
feat: Add support for device compilation setting #2190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
gs-olive
merged 1 commit into
pytorch:main
from
gs-olive:device_setting_checking_standardization
Aug 25, 2023
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import collections.abc | ||
import logging | ||
from typing import Any, List, Optional, Sequence, Set, Tuple | ||
from typing import Any, List, Optional, Sequence, Set, Tuple, Union | ||
|
||
import torch | ||
import torch_tensorrt | ||
|
@@ -13,6 +13,7 @@ | |
from torch_tensorrt.dynamo import CompilationSettings, partitioning | ||
from torch_tensorrt.dynamo._defaults import ( | ||
DEBUG, | ||
DEVICE, | ||
ENABLE_EXPERIMENTAL_DECOMPOSITIONS, | ||
MAX_AUX_STREAMS, | ||
MIN_BLOCK_SIZE, | ||
|
@@ -29,7 +30,11 @@ | |
convert_module, | ||
repair_long_or_double_inputs, | ||
) | ||
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs | ||
from torch_tensorrt.dynamo.utils import ( | ||
prepare_inputs, | ||
to_torch_device, | ||
to_torch_tensorrt_device, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -38,7 +43,7 @@ def compile( | |
gm: Any, | ||
inputs: Any, | ||
*, | ||
device: Device = Device._current_device(), | ||
device: Optional[Union[Device, torch.device, str]] = DEVICE, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Device can now be |
||
disable_tf32: bool = False, | ||
sparse_weights: bool = False, | ||
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), | ||
|
@@ -82,7 +87,9 @@ def compile( | |
if not isinstance(inputs, collections.abc.Sequence): | ||
inputs = [inputs] | ||
|
||
_, torch_inputs = prepare_inputs(inputs, prepare_device(device)) | ||
device = to_torch_tensorrt_device(device) | ||
|
||
_, torch_inputs = prepare_inputs(inputs, to_torch_device(device)) | ||
|
||
if ( | ||
torch.float16 in enabled_precisions | ||
|
@@ -105,6 +112,7 @@ def compile( | |
compilation_options = { | ||
"precision": precision, | ||
"debug": debug, | ||
"device": device, | ||
"workspace_size": workspace_size, | ||
"min_block_size": min_block_size, | ||
"torch_executed_ops": torch_executed_ops | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import logging | ||
from dataclasses import fields, replace | ||
from typing import Any, Callable, Dict, Optional, Sequence | ||
from typing import Any, Callable, Dict, Optional, Sequence, Union | ||
|
||
import torch | ||
import torch_tensorrt | ||
|
@@ -116,23 +116,45 @@ def prepare_inputs( | |
) | ||
|
||
|
||
def prepare_device(device: Device | torch.device) -> torch.device: | ||
_device: torch.device | ||
def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device: | ||
"""Cast a device-type to torch.device | ||
|
||
Returns the corresponding torch.device | ||
""" | ||
if isinstance(device, Device): | ||
if device.gpu_id != -1: | ||
_device = torch.device(device.gpu_id) | ||
return torch.device(device.gpu_id) | ||
else: | ||
raise ValueError("Invalid GPU ID provided for the CUDA device provided") | ||
|
||
elif isinstance(device, torch.device): | ||
_device = device | ||
return device | ||
|
||
elif device is None: | ||
return torch.device(torch.cuda.current_device()) | ||
|
||
else: | ||
raise ValueError( | ||
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" | ||
) | ||
return torch.device(device) | ||
|
||
return _device | ||
|
||
def to_torch_tensorrt_device( | ||
device: Optional[Union[Device, torch.device, str]] | ||
) -> Device: | ||
"""Cast a device-type to torch_tensorrt.Device | ||
|
||
Returns the corresponding torch_tensorrt.Device | ||
""" | ||
if isinstance(device, Device): | ||
return device | ||
|
||
elif isinstance(device, torch.device): | ||
return Device(gpu_id=device.index) | ||
|
||
elif device is None: | ||
return Device(gpu_id=torch.cuda.current_device()) | ||
Comment on lines
+153
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the device is |
||
|
||
else: | ||
return Device(device) | ||
|
||
|
||
def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: | ||
|
@@ -184,7 +206,17 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: | |
# Parse input runtime specification | ||
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) | ||
|
||
logger.info("Compilation Settings: %s\n", settings) | ||
# Ensure device is a torch_tensorrt Device | ||
settings.device = to_torch_tensorrt_device(settings.device) | ||
|
||
# Check and update device settings | ||
if "device" not in kwargs: | ||
logger.info( | ||
f"Device not specified, using Torch default current device - cuda:{settings.device.gpu_id}. " | ||
"If this is incorrect, please specify an input device, via the device keyword." | ||
) | ||
|
||
logger.info(f"Compiling with Settings:\n{settings}") | ||
|
||
return settings | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Device is now populated with a default factory once the
CompilationSettings
object gets instantiated.