feat: Autocast #3878
Open
zewenli98 wants to merge 7 commits into main from autocast
+1,015 −6
Changes from 3 commits
Commits (7):
eac8809 implement autocast
f6c7c7c fix bug
f7d8068 add arg enable_autocast
e15ce94 change names of API and support for user specified node names
94757d2 support dataloader for calibration
4bf12e7 fix comments
0a62149 optimize Cast insertion logic, fix io dtype issue and comments, and a…
New file:
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
import torch_tensorrt
import torchvision


class MyModule(torch.nn.Module):
    def forward(self, a_float32, b_float32, c_float32, d_float32):
        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)
        return g_float16


class AutocastExample(nn.Module):
    def __init__(self):
        super(AutocastExample, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 8 * 8, 10)

    def forward(self, x, y):
        out = self.pool1(self.relu1(self.conv1(x)))  # fp16
        x = self.pool2(self.relu2(self.conv2(out)))  # fp16
        x = self.flatten(x)
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
            x = self.fc1(x)  # fp32
            with torch.autocast(x.device.type, enabled=False):
                x = torch.sub(x.half(), y)  # fp16
                out2 = torch.add(x, x)  # fp16
            with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
                out2 = torch.log(out2)  # fp32
        return x, out, out2
class MyResNet18Wrapper(torch.nn.Module):
    def __init__(self, num_classes=1000, pretrained=True):
        super(MyResNet18Wrapper, self).__init__()
        self.resnet = torchvision.models.resnet18(
            num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None
        )

    def forward(self, x):
        x = self.resnet(x)
        return x
if __name__ == "__main__":
    # model = MyModule().cuda().eval()
    # inputs = (torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),
    #           torch.randn((8, 8), device="cuda"),)

    # model = AutocastExample().cuda().eval()
    # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"),
    #           torch.randn((1,), dtype=torch.float16, device="cuda"),)
    model = MyResNet18Wrapper().cuda().eval()
    inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),)

    ep = torch.export.export(model, inputs)

    with torch_tensorrt.dynamo.Debugger(
        "graphs",
        logging_dir=".",
        engine_builder_monitor=False,
    ):
        trt_mod = torch_tensorrt.compile(
            ep.module(),
            arg_inputs=inputs,
            min_block_size=1,
            use_python_runtime=True,
            ##### weak typing #####
            # use_explicit_typing=False,
            # enabled_precisions={torch.float16},
            ##### strong typing + autocast #####
            use_explicit_typing=True,
            enable_autocast=True,
            low_precision_type=torch.float16,
            # nodes_to_exclude={"^conv2d$"},
            targets_to_exclude={},
            data_max=512,
            max_depth_of_reduction=None,
        )

    trt_out = trt_mod(*inputs)
Modified file:
@@ -141,7 +141,7 @@ def cross_compile_for_windows(
        disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
        sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
-       enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+       enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
        capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
        workspace_size (int): Maximum size of workspace given to TensorRT
@@ -434,6 +434,14 @@ def compile(
    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
    enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
    low_precision_type: Optional[
        Union[torch.dtype, dtype]
    ] = _defaults.LOW_PRECISION_TYPE,
    nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
    targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
    data_max: float = _defaults.DATA_MAX,
    max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
    **kwargs: Any,
) -> torch.fx.GraphModule:
    """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -511,6 +519,12 @@ def compile(
        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
        offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
        use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
        enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
        low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
        nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is [].
        targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
        data_max (float): Maximum absolute value for node outputs; nodes with outputs greater than this value will remain in FP32. Default is 512.
        max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None.
        **kwargs: Any,
    Returns:
        torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
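To make the new arguments documented above concrete, here is a minimal sketch of how they might be combined in a compile call. The toy Sequential model, the "^linear$" node-name pattern, and the aten.relu.default target below are illustrative assumptions, not values taken from this PR:

import torch
import torch_tensorrt

# Hypothetical toy model and inputs, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().eval()
inputs = (torch.randn(8, 64, device="cuda"),)
ep = torch.export.export(model, inputs)

trt_mod = torch_tensorrt.compile(
    ep.module(),
    arg_inputs=inputs,
    use_explicit_typing=True,  # enable_autocast=True also forces this on
    enable_autocast=True,
    low_precision_type=torch.float16,  # or torch.bfloat16
    nodes_to_exclude={"^linear$"},  # regex on node names; pattern is illustrative
    targets_to_exclude={torch.ops.aten.relu.default},  # ATen targets kept in FP32
    data_max=512,  # nodes whose recorded outputs exceed this stay in FP32
    max_depth_of_reduction=None,  # no limit on reduction depth
)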
@@ -584,6 +598,10 @@ def compile(
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    if enable_autocast:
        use_explicit_typing = True
        logger.debug("Autocast is enabled, setting use_explicit_typing to True.")

    if use_explicit_typing:
        if len(enabled_precisions) != 1 or not any(
            x in enabled_precisions
@@ -593,6 +611,19 @@ def compile(
                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
            )

    if low_precision_type is not None:
        if not isinstance(low_precision_type, (torch.dtype, dtype)):
            raise ValueError(
                f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}"
            )
        if low_precision_type not in {
            torch.float16,
            torch.bfloat16,
        } and low_precision_type not in {dtype.f16, dtype.bf16}:
            raise ValueError(
                f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}"
            )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
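For readers of the low_precision_type validation above, a small sketch of which values pass and which raise. This snippet is illustrative only and is not part of this PR:

import torch
from torch_tensorrt import dtype

# Both the torch and torch_tensorrt enum spellings are accepted by the checks above:
accepted = [torch.float16, torch.bfloat16, dtype.f16, dtype.bf16]

# Anything else fails the isinstance/membership checks, e.g.:
#   low_precision_type=torch.float32  -> ValueError (not fp16/bf16)
#   low_precision_type="fp16"         -> ValueError (not a torch.dtype or dtype)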
@@ -622,6 +653,38 @@ def compile(
    if not isinstance(arg_inputs, collections.abc.Sequence):
        arg_inputs = [arg_inputs]  # type: ignore

    # save intermediate outputs of each node for Autocast
    intermediate_node_outputs = {}
    if not use_explicit_typing:

        class DumpInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
            """Dump intermediate outputs of each node"""

            def run_node(self, n: torch.fx.Node) -> Any:
                if (
                    n.op == "call_function"
                    and n.target != torch.ops.higher_order.wrap_with_autocast
                ):
                    out = super().run_node(n)
                    if not isinstance(out, torch.Tensor):
                        raise ValueError(
                            f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}."
                        )
                    intermediate_node_outputs[n.name] = out
                    return out
                return super().run_node(n)

        def _materialize(x: Input | torch.Tensor) -> torch.Tensor:
            """Materialize an Input object to a tensor"""
            if isinstance(x, Input):
                return x.torch_tensor
            return x

        with torch.no_grad():
            mat_args = tuple(_materialize(a) for a in arg_inputs)
            mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()}
            DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs)
    # Prepare torch_trt inputs
    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
@@ -680,6 +743,13 @@ def compile(
        "l2_limit_for_tiling": l2_limit_for_tiling,
        "offload_module_to_cpu": offload_module_to_cpu,
        "use_distributed_mode_trace": use_distributed_mode_trace,
        "enable_autocast": enable_autocast,
        "low_precision_type": low_precision_type,
        "nodes_to_exclude": nodes_to_exclude,
        "targets_to_exclude": targets_to_exclude,
        "data_max": data_max,
        "max_depth_of_reduction": max_depth_of_reduction,
        "intermediate_node_outputs": intermediate_node_outputs,
    }

    settings = CompilationSettings(**compilation_options)
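The DumpInterpreter hunk above records every call_function output into intermediate_node_outputs, which is then forwarded through compilation_options. The pass that consumes it is not shown in this diff; the helper below is only a guess at how the data_max rule described in the docstring could be applied to those recorded tensors (the name and logic are assumptions, not the PR's implementation):

import torch

def keep_in_fp32_by_magnitude(
    intermediate_node_outputs: dict[str, torch.Tensor], data_max: float
) -> set[str]:
    """Hypothetical helper: names of nodes whose recorded outputs exceed data_max.

    Per the enable_autocast docstring, such nodes should remain in FP32 rather than
    be downcast to the low-precision type.
    """
    keep: set[str] = set()
    for name, out in intermediate_node_outputs.items():
        if out.abs().max().item() > data_max:
            keep.add(name)
    return keep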
Review comment: Is this not necessary now?