Cherrypick #3481 and #3445 (#3498)

Merged · 1 commit · Apr 29, 2025
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_features.py
@@ -37,7 +37,7 @@
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
_FX_FE_AVAIL = True
_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
_REFIT_AVAIL = True

ENABLED_FEATURES = FeatureSet(
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
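This hunk drops the Python-version guard on refit: `_REFIT_AVAIL` was previously false on Python 3.13+, and is now always true, so entry points decorated with `needs_refit` stay enabled. Below is a minimal sketch of how such a feature-gate decorator typically works; it is illustrative only and not the library's exact implementation.

```python
import functools
from typing import Any, Callable

# Mirrors the flag flipped in _features.py; previously it was
# version.parse(sys.version.split()[0]) < version.parse("3.13").
_REFIT_AVAIL = True


def needs_refit(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Illustrative gate: raise a clear error when refit support is unavailable."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not _REFIT_AVAIL:
            raise RuntimeError(
                f"{fn.__name__} requires refit support, which is disabled in this build"
            )
        return fn(*args, **kwargs)

    return wrapper
```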
159 changes: 66 additions & 93 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -9,6 +9,7 @@
import tensorrt as trt
import torch
from torch.export import ExportedProgram
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import needs_refit
from torch_tensorrt._Input import Input
@@ -61,26 +62,13 @@ def construct_refit_mapping(
Returns:
Mapping from weight name in TensorRT to actual weight value in np.ndarray
"""
MODULE_MAP = {
"SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
"CONVOLUTION": (
trt.IConvolutionLayer,
[("kernel", "KERNEL"), ("bias", "BIAS")],
),
"DECONVOLUTION": (
trt.IDeconvolutionLayer,
[("kernel", "KERNEL"), ("bias", "BIAS")],
),
"CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
}

output_dtypes = infer_module_output_dtypes(
module,
truncate_double=settings.truncate_double,
)

# Use Interpreter
weight_map = {}
interpreter = TRTInterpreter(
module,
inputs,
@@ -89,24 +77,8 @@
compilation_settings=settings,
)
interpreter._construct_trt_network_def()
net = interpreter.ctx.net
for i in range(net.num_layers):
layer = net[i]
layer_type: str = layer.type.name
if layer_type in MODULE_MAP:
# Cast the parent class to child class to access attributes
# For example: ILayer does not have ILayer.kernel/ILayer.bias
# So we cast it to IConvolutionLayer and access the attributes
layer.__class__ = MODULE_MAP[layer_type][0]
for weight_type, weight_name in MODULE_MAP[layer_type][1]:
weight = layer.__getattribute__(weight_type).copy()
weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
weight_map[f"{layer.name} {weight_name}"] = (
weight,
weight_dtype,
)

return weight_map
return interpreter.ctx.mapping


@needs_refit
@@ -117,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)

if sd_weight_name not in state_dict:
# If weights is not in sd, we can leave it unchanged
continue
else:
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
to_torch_device(settings.device)
)
@@ -152,71 +123,73 @@ def _refit_single_trt_engine_with_gm(
Refit a TensorRT Engine in place
"""

refitted = set()
torch_device = get_model_device(new_gm)
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

if weight_name_map:
# Get the refitting mapping
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)
with unset_fake_temporarily():
refitted = set()
torch_device = get_model_device(new_gm)
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

if weight_name_map:
# Get the refitting mapping
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)

constant_mapping: dict[str, Any] = weight_name_map.pop(
"constant_mapping", {}
) # type: ignore
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict(), settings
)
constant_mapping_with_type = {}

for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
constant_mapping: dict[str, Any] = weight_name_map.pop(
"constant_mapping", {}
) # type: ignore
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict(), settings
)
constant_mapping_with_type = {}

for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
)

mapping.update(constant_mapping_with_type)
mapping.update(constant_mapping_with_type)

for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use Numpy to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"
for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use Numpy to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"

else:
mapping = construct_refit_mapping(new_gm, input_list, settings)
trt_wt_location = trt.TensorLocation.HOST
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")

if not refitter.refit_cuda_engine():
logger.error("Error: failed to refit new weights.")
raise AssertionError("Refitting failed.")
else:
mapping = construct_refit_mapping(new_gm, input_list, settings)
trt_wt_location = trt.TensorLocation.HOST
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight = mapping[layer_name]
trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")

if not refitter.refit_cuda_engine():
logger.error("Error: failed to refit new weights.")
raise AssertionError("Refitting failed.")


@needs_refit
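For orientation, the slow-path loop above boils down to the standard TensorRT refit sequence: wrap the built engine in a `trt.Refitter`, supply each named weight as a `trt.Weights` view over host memory, confirm nothing required is missing, and call `refit_cuda_engine()`. A condensed sketch, assuming `mapping` holds `(numpy_array, trt.DataType)` pairs that stay alive until the refit completes:

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def refit_engine(engine: trt.ICudaEngine, mapping: dict) -> None:
    """Condensed refit loop over host-resident weights."""
    refitter = trt.Refitter(engine, TRT_LOGGER)
    location = trt.TensorLocation.HOST  # weights are staged in CPU memory here

    for name in refitter.get_all_weights():
        if name not in mapping:
            continue  # leave this weight untouched
        array, trt_dtype = mapping[name]
        weights = trt.Weights(trt_dtype, array.ctypes.data, array.size)
        refitter.set_named_weights(name, weights, location)

    # Refit fails if any weight the engine still requires was not supplied.
    assert len(refitter.get_missing_weights()) == 0, "incomplete weight mapping"
    if not refitter.refit_cuda_engine():
        raise AssertionError("Refitting failed.")
```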
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field

import numpy as np
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.fx.types import TRTNetwork

@@ -19,3 +20,4 @@ class ConversionContext:
default_factory=CompilationSettings
)
requires_output_allocator: bool = False
mapping: dict[str, np.array] = field(default_factory=dict)
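`ConversionContext` now carries a `mapping` dict, presumably populated by the conversion utilities as weights are added to the network (not shown in this diff), which is what lets `construct_refit_mapping` simply return `interpreter.ctx.mapping` instead of re-walking every layer. A hypothetical sketch of how a converter might record a weight; `record_weight` is an invented helper, not a torch_tensorrt API:

```python
import numpy as np
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext


def record_weight(ctx: ConversionContext, layer_name: str, role: str, value: np.ndarray) -> None:
    # Keys follow the "<layer name> <role>" convention used by _save_weight_mapping,
    # e.g. "conv1 KERNEL", so they line up with Refitter.get_all_weights().
    ctx.mapping[f"{layer_name} {role}"] = value
```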
54 changes: 24 additions & 30 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -21,6 +21,7 @@
import tensorrt as trt
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes
@@ -42,6 +43,7 @@
get_node_io,
get_node_name,
get_trt_tensor,
to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
from torch_tensorrt.fx.observer import Observer
@@ -410,27 +412,29 @@ def find_weight(
np_map: the map from weight name to np values in INetworkDefinition
state_dict: state of the graph module
"""
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""
with unset_fake_temporarily():
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""

@staticmethod
def check_weight_equal(
sd_weight: torch.tensor,
network_weight: Union[torch.Tensor, np.ndarray],
device: torch.device,
) -> Any:
if not isinstance(network_weight, torch.Tensor):
network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
)
except Exception:
return torch.all(sd_weight == network_weight)
with unset_fake_temporarily():
if not isinstance(network_weight, torch.Tensor):
network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
)
except Exception:
return torch.all(sd_weight == network_weight)

@needs_refit
def _save_weight_mapping(self) -> None:
@@ -495,19 +499,15 @@ def _save_weight_mapping(self) -> None:
for k, v in self.module.state_dict().items()
}
weight_name_map: dict[str, Any] = {}
np_map = {}
constant_mapping = {}
np_map = self.ctx.mapping
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
net = self.ctx.net
for i in range(net.num_layers):
layer = net[i]
layer_type: str = layer.type.name
if layer_type in MODULE_MAP:
layer.__class__ = MODULE_MAP[layer_type][0]
# Name mapping
for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
weight = layer.__getattribute__(weight_type).copy()
if weight.size == 0:
continue
engine_weight_name = f"{layer.name} {weight_name}"
# Infer the corresponding weight name(s) in state_dict
sd_weight_name_list = (
@@ -535,17 +535,15 @@
elif "bias" in suffix:
sd_weight_name = f"{sd_weight_name}.bias"
else:
# Save the constant weights for future fast refit
sd_weight_name = f"{sd_weight_name}.unknown"
constant_mapping[engine_weight_name] = weight
elif layer_type == "SCALE":
# Batch norm needs all weights to calculate scale and shift
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
else:
sd_weight_name = f"{sd_weight_name}.{torch_attr}"

weight_name_map[engine_weight_name] = sd_weight_name
np_map[engine_weight_name] = weight
if engine_weight_name in np_map:
weight_name_map[engine_weight_name] = sd_weight_name

# Stage 2: Value mapping
for engine_weight_name, sd_weight_name in weight_name_map.items():
@@ -887,19 +885,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
return converter(self.ctx, target, args, kwargs, self._cur_node_name)

def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
with _disable_current_modes():
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy

with _disable_current_modes(), unset_fake_temporarily():
frozen_attr = self.fetch_attr(target)

if isinstance(frozen_attr, torch.nn.Parameter):
constant_tensor = frozen_attr.data
else:
constant_tensor = frozen_attr

network_constant = to_numpy(constant_tensor)

return network_constant
return to_torch(constant_tensor)

def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
assert isinstance(target, str)
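Several of the touched paths in this file (`find_weight`, `check_weight_equal`, `get_attr`) are now wrapped in `unset_fake_temporarily`. The reason, sketched below under the assumption that a `FakeTensorMode` is active during tracing: inside the fake mode, newly created tensors carry no real storage, so comparing or copying actual weight values requires temporarily dropping back to real tensors.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(2, 2)        # FakeTensor: shape/dtype only, no real data
    with unset_fake_temporarily():
        real = torch.ones(2, 2)    # regular tensor with real storage
        total = real.sum().item()  # safe to read concrete values here
    print(type(fake), total)
```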