3 files changed: 11 additions, 12 deletions
First changed file:

 import numpy as np
 import tensorrt as trt
 import torch
-import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
+
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR

@@ -152,9 +153,9 @@ def cast_trt_tensor(
 ) -> TRTTensor:
     """Given a TRT Tensor, convert that Tensor to the specified dtype

-    Adds an Identity layer to the network which performs the conversion
-    if the input's dtype is different from the cast type. Otherwise returns
-    input unchanged
+    Adds a Cast layer to the network to convert the input tensor to the specified dtype.
+    If the input tensor already has the desired dtype, it is returned unchanged.
+    Otherwise, a Cast layer is added to perform the conversion.

     Args:
         ctx (ConversionContext): A ConversionContext containing the TensorRT network
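As a quick illustration of the cast-if-needed behavior described in the docstring above, a minimal sketch follows. The helper name cast_if_needed is hypothetical and this is not the actual cast_trt_tensor implementation; it only relies on TensorRT's INetworkDefinition.add_cast, the same API this change adopts.

import tensorrt as trt

def cast_if_needed(
    network: trt.INetworkDefinition,
    tensor: trt.ITensor,
    dtype: trt.DataType,
    name: str,
) -> trt.ITensor:
    # Return the tensor unchanged when it already has the requested dtype.
    if tensor.dtype == dtype:
        return tensor
    # Otherwise insert a single Cast layer and return its output.
    layer = network.add_cast(tensor, dtype)
    layer.name = f"{name}_cast"
    return layer.get_output(0)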
Second changed file:

 import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 ...
     flatten_dims,
     get_positive_dim,
     get_trt_tensor,
+    has_dynamic_shape,
+    prepend_ones,
+    set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
 ...
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.converters.converter_utils import (
-    has_dynamic_shape,
-    prepend_ones,
-    set_layer_name,
-)
 from torch_tensorrt.fx.types import Shape, TRTTensor


@@ -230,7 +229,7 @@ def expand(
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
-            ctx.net,
+            ctx,
             input_t,
             name + "_expand_broadcast",
             shape_rank - initial_tensor_rank,
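For intuition, the rank padding that prepend_ones performs on the TensorRT graph here is conceptually the same as reshaping a tensor to prepend size-1 dimensions until it matches the target rank. The PyTorch sketch below is an analogy only, not the converter's implementation, and prepend_ones_shape is a hypothetical name.

import torch

def prepend_ones_shape(t: torch.Tensor, target_rank: int) -> torch.Tensor:
    # Prepend size-1 dimensions so the tensor can broadcast against a
    # higher-rank shape, as the expand converter needs.
    num_prepend = target_rank - t.dim()
    return t.reshape((1,) * num_prepend + tuple(t.shape))

x = torch.randn(3, 4)
y = prepend_ones_shape(x, 4)
print(y.shape)  # torch.Size([1, 1, 3, 4])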
Third changed file:

@@ -909,7 +909,6 @@ def type_cast(
     """
     This function helps to cast the input type to cast_type
     """
-    layer_i = network.add_identity(input)
-    layer_i.set_output_type(0, cast_type)
+    layer_i = network.add_cast(input, cast_type)
     set_layer_name(layer_i, target, f"{name}_dtype_change")
     return layer_i.get_output(0)
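The change above replaces the legacy Identity-plus-set_output_type idiom with TensorRT's dedicated Cast layer. Below is a minimal, self-contained sketch of both idioms, assuming a TensorRT version that exposes INetworkDefinition.add_cast and using a hypothetical network with a single input x.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 224, 224))

# Old idiom: an Identity layer with a forced output type.
identity = network.add_identity(x)
identity.set_output_type(0, trt.float16)

# New idiom: a dedicated Cast layer.
cast = network.add_cast(x, trt.float16)
cast.name = "x_dtype_change"

network.mark_output(cast.get_output(0))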