feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Open
wants to merge 16 commits into main
refactored, there are 3 types of engines
zewenli98 committed Oct 15, 2024
commit 493f9810a43c2ea2b05dbd187db9cfef4eb56f1c
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -165,7 +165,7 @@ def compile(
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs.
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -613,7 +613,7 @@ def convert_exported_program_to_serialized_trt_engine(
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs.
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
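For context, a minimal usage sketch of the options documented above, assuming a CUDA GPU and a torchvision ResNet18; the flag names come from this diff, while the model, shapes, and workflow are illustrative:

```python
# Illustrative sketch only; not part of this PR.
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = (torch.randn(1, 3, 224, 224).to("cuda"),)
exp_program = torch.export.export(model, inputs)

# Weight-stripped, refittable engine: the serialized plan carries no weights.
trt_gm_stripped = torch_tensorrt.dynamo.compile(
    exp_program, inputs, min_block_size=1, strip_engine_weights=True
)

# Non-refittable engine: strip_engine_weights / refit_identical_engine_weights are ignored.
trt_gm_immutable = torch_tensorrt.dynamo.compile(
    exp_program, inputs, min_block_size=1, immutable_weights=True
)

# The same flags apply to the serialized-engine path, which returns raw plan bytes.
engine_bytes = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exp_program, inputs, refit_identical_engine_weights=True
)
```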
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_settings.py
@@ -85,7 +85,7 @@ class CompilationSettings:
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Whether to refit the engine with identical weights
strip_engine_weights (bool): Whether to strip the engine weights
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
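The same three knobs surface as fields on `CompilationSettings`. A rough sketch of combining them, assuming direct construction of the dataclass is desired (normally they are passed as keyword arguments to `compile`):

```python
# Hypothetical direct construction; field names are taken from the docstring above.
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    strip_engine_weights=True,             # serialize the plan without weights
    refit_identical_engine_weights=False,  # allow refitting with arbitrary weights
    immutable_weights=False,               # keep the engine refittable
)
```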
71 changes: 43 additions & 28 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -287,7 +287,14 @@ def _populate_trt_builder_config(
if self.compilation_settings.disable_tf32:
builder_config.clear_flag(trt.BuilderFlag.TF32)

if not self.compilation_settings.immutable_weights:
if self.compilation_settings.immutable_weights:
# non-refittable engine
if self.compilation_settings.strip_engine_weights:
_LOGGER.warning(
"You cannot get a non-refittable engine with weights stripped. `strip_engine_weights` will be set to false and engine caching will be disabled."
)
else:
# refittable engine
if version.parse(trt.__version__) >= version.parse("10.0"):
if self.compilation_settings.refit_identical_engine_weights:
builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
@@ -296,7 +303,8 @@
else:
builder_config.set_flag(trt.BuilderFlag.REFIT)

builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
if self.compilation_settings.strip_engine_weights:
builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)

if strict_type_constraints:
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
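Per the commit message, this hunk distinguishes three types of engines. A condensed sketch of the resulting TensorRT builder-flag combinations (TensorRT >= 10 Python API; the three boolean values below are example assumptions):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Example settings (assumed values).
immutable_weights = False
refit_identical_engine_weights = True
strip_engine_weights = True

# Type 1: immutable -> no refit or strip flags at all.
if not immutable_weights:
    if refit_identical_engine_weights:
        config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)  # refittable only with identical weights
    else:
        config.set_flag(trt.BuilderFlag.REFIT)  # Type 2: generally refittable
    if strip_engine_weights:
        config.set_flag(trt.BuilderFlag.STRIP_PLAN)  # Type 3: weight-stripped plan
```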
@@ -564,7 +572,7 @@ def run(
cached_data = self.engine_cache.check(hash_val)
if cached_data is not None: # hit the cache
(
serialized_engine,
weight_stripped_serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
@@ -598,7 +606,9 @@
# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
engine = runtime.deserialize_cuda_engine(
weight_stripped_serialized_engine
)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
@@ -620,6 +630,7 @@ def run(
serialized_engine = engine.serialize_with_config(
serialization_config
)
# As of now, the engine becomes non-refittable because when the EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
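On a cache hit with `strip_engine_weights=False`, the cached weight-stripped plan is refitted and re-serialized with its weights included. Below is a standalone sketch of that flow using the raw TensorRT refit API; torch-tensorrt refits via `_refit_single_trt_engine_with_gm` instead, and `named_weights` here is a hypothetical mapping from layer weight names to numpy arrays:

```python
import tensorrt as trt

def rehydrate_engine(stripped_plan: bytes, named_weights: dict) -> bytes:
    """Refit a weight-stripped engine, then serialize it with weights included."""
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(stripped_plan)

    refitter = trt.Refitter(engine, logger)
    for name, weights in named_weights.items():
        refitter.set_named_weights(name, weights)
    assert refitter.refit_cuda_engine(), "refit failed"

    # Clearing EXCLUDE_WEIGHTS bakes the weights back into the serialized plan.
    serialization_config = engine.create_serialization_config()
    serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    return engine.serialize_with_config(serialization_config)
```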
@@ -661,17 +672,43 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

# refittable engine
if not self.compilation_settings.immutable_weights:
# Disable engine caching for non-refittable engines
# Engine caching only for refittable engine
if (
self.engine_cache is not None
and self.compilation_settings.cache_built_engines
):
# Cache the weight-stripped engine
if self.compilation_settings.strip_engine_weights:
weight_stripped_serialized_engine = serialized_engine
else:
# Serialize the refitted engine with the EXCLUDE_WEIGHTS flag set
runtime = trt.Runtime(TRT_LOGGER)
Collaborator comment: Why do we refit and then strip the weights again? If refit is enabled, shouldn't the builder always use a weight-stripped engine?

engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

serialization_config = engine.create_serialization_config()
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
weight_stripped_serialized_engine = engine.serialize_with_config(
serialization_config
)

self.engine_cache.insert(
hash_val,
(
serialized_engine,
weight_stripped_serialized_engine,
self._input_names,
self._output_names,
self.input_specs,
Expand All @@ -680,28 +717,6 @@ def run(
),
)

if not self.compilation_settings.strip_engine_weights:
# Refit the engine with the original weights
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

# Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialized_engine = engine.serialize_with_config(serialization_config)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
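For the cache-insert path above, the inverse of rehydration is setting EXCLUDE_WEIGHTS so the cached plan carries no weights; a minimal sketch (the helper name is hypothetical, the calls mirror the serialization block above):

```python
import tensorrt as trt

def strip_weights_for_cache(engine: trt.ICudaEngine) -> bytes:
    """Serialize a refittable engine without its weights, e.g. for the engine cache."""
    serialization_config = engine.create_serialization_config()
    serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    return engine.serialize_with_config(serialization_config)
```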
89 changes: 87 additions & 2 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -15,6 +15,56 @@


class TestWeightStrippedEngine(TestCase):
def test_three_ways_to_compile(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
exp_program = torch.export.export(pyt_model, example_inputs)

settings = {
"use_python_runtime": False,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
"strip_engine_weights": False,
"refit_identical_engine_weights": False,
}

# 1. Compile with torch_trt.dynamo.compile
gm1 = torch_trt.dynamo.compile(
exp_program,
example_inputs,
**settings,
)
gm1_output = gm1(*example_inputs)

# 2. Compile with torch_trt.compile using dynamo backend
gm2 = torch_trt.compile(
pyt_model, ir="dynamo", inputs=example_inputs, **settings
)
gm2_output = gm2(*example_inputs)

# 3. Compile with torch.compile using tensorrt backend
gm3 = torch.compile(
pyt_model,
backend="tensorrt",
options=settings,
)
gm3_output = gm3(*example_inputs)

pyt_model_output = pyt_model(*example_inputs)

assert torch.allclose(
pyt_model_output, gm1_output, 1e-2, 1e-2
), "gm1_output is not correct"

assert torch.allclose(
gm1_output, gm2_output, 1e-2, 1e-2
), "gm2_output is not correct"

assert torch.allclose(
gm2_output, gm3_output, 1e-2, 1e-2
), "gm3_output is not correct"

def test_weight_stripped_engine_sizes(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
@@ -67,8 +117,6 @@ def test_weight_stripped_engine_results(self):
enabled_precisions={torch.float},
debug=False,
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
strip_engine_weights=True,
refit_identical_engine_weights=False,
)
@@ -316,3 +364,40 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
times[0] > times[2],
msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
)

def test_different_args_dont_share_engine_caching(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")

engine_cache_dir = "/tmp/test_different_args_dont_share_engine_caching"
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)

inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]

for i in range(2):
if i == 0:
strip_engine_weights = False
else:
strip_engine_weights = True

compiled_model = torch.compile(
pyt_model,
backend="tensorrt",
options={
"use_python_runtime": True,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
"cache_built_engines": True,
"reuse_cached_engines": True,
"engine_cache_dir": engine_cache_dir,
"strip_engine_weights": strip_engine_weights,
},
)
compiled_model(*inputs)

assertions.assertEqual(
len(os.listdir(engine_cache_dir)),
2,
msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines",
)