feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167
Conversation
```python
name: str = "",
settings: CompilationSettings = CompilationSettings(),  # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
```
@narendasan I tried to do refitting for the C++ runtime like for the Python runtime, but it didn't work. Any suggestions? Should I do it in C++ or Python?
Doesn't refit already work on both APIs?
Also, why do we need the graph module in this module?
In this PR I moved the refitting part into TRTModule, so it only works for the Python runtime.

The graph module is used for refitting.
```diff
@@ -619,27 +609,32 @@ def run(
         builder_config, self.compilation_settings.timing_cache_path
     )

-    serialized_engine = self.builder.build_serialized_network(
+    # if strip_engine_weights is true, the serialized engine needs to be refitted before using
+    maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
```
Why is this a "maybe unrefitted" engine?
Please see the design in the comment below. If `compilation_settings.strip_engine_weights` is true, it needs to be refitted; otherwise it doesn't. So it's "maybe".
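For illustration, a minimal sketch of that conditional around the builder call (assuming TensorRT 10's `BuilderFlag.STRIP_PLAN`, which also requires a refit flag on the config; the helper and its wiring are hypothetical, not the PR's exact code):

```python
import tensorrt as trt

def build_maybe_unrefitted_engine(builder, network, builder_config, strip_engine_weights):
    # If weights are stripped, the resulting plan must be refitted before use.
    if strip_engine_weights:
        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)  # TensorRT 10 flag
    # "maybe unrefitted": weight-stripped if strip_engine_weights, else complete.
    return builder.build_serialized_network(network, builder_config)
```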
), "weight-stripped engines must be refittable, please set make_refittable=True" | ||
|
||
# Refit the weights | ||
refitter = trt.Refitter(self.engine, TRT_LOGGER) |
Can you use this function?
`_refit_single_trt_engine_with_gm` in `py/torch_tensorrt/dynamo/_refit.py` (line 138 at fa02fd3)
The function requires `input_list`, which is not provided in the caller.
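For reference, a minimal sketch of what the inline `trt.Refitter` path does with a graph module's weights (the name-mapping loop is illustrative; torch_tensorrt's actual `weight_name_map` format may differ):

```python
import tensorrt as trt
import torch

def refit_engine_from_gm(engine, gm: torch.fx.GraphModule, weight_name_map, logger):
    # Push each graph-module weight into the engine under its TRT weight name.
    refitter = trt.Refitter(engine, logger)
    state_dict = gm.state_dict()
    for engine_weight_name, gm_weight_name in weight_name_map.items():
        tensor = state_dict[gm_weight_name].cpu().contiguous().numpy()
        refitter.set_named_weights(engine_weight_name, trt.Weights(tensor))
    if not refitter.refit_cuda_engine():
        raise RuntimeError("TensorRT refit failed")
```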
```diff
@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
     self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
     self.context = self.engine.create_execution_context()

+    if self.settings.strip_engine_weights:
```
We likely shouldn't be doing the refit in these modules.

I think for weight stripping there are 3 workflows (sketched in code below):

1. A user just wants a weight-stripped engine. They should use `convert_exported_program_to_trt_engine` with the setting `strip_weights`. The choice of `make_refittable` can be used to decide between `kREFIT` and `kREFIT_IDENTICAL` (though it might not be entirely clear, so we might want to think about that setting).
2. We want to utilize weight stripping to have a lighter-weight cache. Here this choice is opaque to the user. The user's choice of `make_refittable` controls whether we use `kREFIT` or `kREFIT_IDENTICAL`. But once the engine is loaded or we pull it from the cache, we immediately refit (prior to passing the engine to the TRTModule), same as we do today.
3. The user wants a stripped-weights compiled program (I'm not sure why, or if this is a real use case). Here, this is basically the same as lazy engine loading. We would require that users run `refit_engine_weights` before executing.
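Roughly, the three workflows might look like this from user code (a sketch only: the `strip_engine_weights`/`make_refittable` option names and the `refit_engine_weights` helper are the ones under discussion in this thread, not a settled API):

```python
import torch
import torch_tensorrt as torch_trt

model = torch.nn.Linear(4, 4).cuda().eval()
inputs = [torch.randn(1, 4).cuda()]
exported_program = torch.export.export(model, tuple(inputs))

# (1) Just produce a weight-stripped engine.
engine_bytes = torch_trt.dynamo.convert_exported_program_to_trt_engine(
    exported_program,
    inputs,
    strip_engine_weights=True,  # proposed option name
    make_refittable=True,       # chooses kREFIT vs. kREFIT_IDENTICAL
)

# (2) Weight stripping as a cache optimization, opaque to the user:
# compile() caches stripped engines and refits on cache hit internally.
trt_module = torch_trt.compile(model, inputs=inputs, make_refittable=True)

# (3) A stripped-weights compiled program; the user must refit before running
# (refit_engine_weights is the hypothetical helper named above).
# trt_module = refit_engine_weights(trt_module, exported_program)
```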
Got it. The very first idea/design is commented below. I'll move the refitting part back to `TRTInterpreter.run()`.

> The choice of `make_refittable` can be used to decide between `kREFIT` and `kREFIT_IDENTICAL`

Do you mean we use `make_refittable` to control both `kREFIT` and `kREFIT_IDENTICAL`?
@zewenli98 do you have a design for this feature?
@narendasan Ok, at first the overall design was like this, in `TRTInterpreter.run()`:

```
if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (don't need to refit)
    else:
        load and return the weight-included engine (not yet refit)
```

Then, in TRTModule, refit if necessary before inference.
@narendasan The design was updated. From the users' perspective, they are able to set [...]. Besides, users can specify [...]. For the 3 workflows mentioned above, please see more details in the tests.
I think that we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing. I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now: [...]

The first 2 need separate cache entries, so we need to be able to hash on the weights in the case that the model is being built with [...]. We should look to prefer case 1 in the long term, as it allows us to reuse the most work; case 2 would be the next preference. Case 2 should produce faster engines than case 1, so there remains a need to support [...].

The case for type 3 engines now is only valid if building a non-refittable engine is faster than building a refit-identical engine and then refitting the weights. If it is not by a significant enough margin, I propose we remove that workflow and just have [...]. So, assuming that we can remove type 3 engines, [...]
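On the "hash on the weights" point, a sketch of what a weight-aware cache key could look like (illustrative only; torch_tensorrt's actual engine-cache hashing scheme differs):

```python
import hashlib
import torch

def weight_aware_cache_key(graph_hash: str, gm: torch.fx.GraphModule) -> str:
    # Fold the weight values into the key so engines built with weight-dependent
    # optimizations (e.g. kREFIT_IDENTICAL) get distinct cache entries per weights,
    # while weight-agnostic (kREFIT) engines can hash on the graph alone.
    h = hashlib.sha256(graph_hash.encode())
    for name, tensor in sorted(gm.state_dict().items()):
        h.update(name.encode())
        h.update(tensor.cpu().contiguous().numpy().tobytes())
    return h.hexdigest()
```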
Some of the open questions are: [...]

Are you referring to [...]?
My current design is: if users specify [...]
I also thought about it earlier. The TRT doc says: "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."

Will investigate it.
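For concreteness, the two builder flags in question (TensorRT 10 Python API; the surrounding config code is a sketch):

```python
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()

# kREFIT: engine is refittable with arbitrary new weights.
config.set_flag(trt.BuilderFlag.REFIT)

# kREFIT_IDENTICAL: the builder may optimize assuming the refit weights are
# identical to the build-time weights; per the TRT docs quoted above,
# refitting with different weights is undefined behavior, but it avoids
# kREFIT's performance penalty.
# config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
```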
@narendasan I tested building ResNet18 and VGG16 via the two paths: (1) [...]
@narendasan I just confirmed with the TRT team; the conclusion is [...]
I think we can rename [...]. On top of this, [...]. In summary, the 3 workflows mentioned above would be: [...]
I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

I still don't know what the use case for this is.

We should think about a solution for this, since the behavior is undefined.
py/torch_tensorrt/dynamo/_refit.py (outdated)
```diff
@@ -414,6 +410,10 @@ def refit_module_weights(
         "The type of graph module is not supported for refitting or two compiled modules do not match."
     )

+    assert (
```
Weight-stripped engines can only be refit once?
```python
# clear EXCLUDE_WEIGHTS flag
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialized_engine = engine.serialize_with_config(serialization_config)
```
Why do we serialize then immediately deserialize here?
Because we want the engine to clear the `EXCLUDE_WEIGHTS` flag. Is there a way to clear the flag without doing serialization?
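In other words, the round trip looks like this (a sketch; the serialization config only affects how the engine is serialized, so getting an in-memory engine with the weights re-embedded requires deserializing the new blob):

```python
import tensorrt as trt

def embed_weights(engine, logger: trt.Logger):
    # Re-serialize with EXCLUDE_WEIGHTS cleared so the weights are embedded again...
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    blob = engine.serialize_with_config(config)
    # ...then deserialize to obtain a self-contained engine object.
    return trt.Runtime(logger).deserialize_cuda_engine(blob)
```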
```diff
 new_engine_info = list(engine_info)
-new_engine_info[ENGINE_IDX] = serialized_engine
+new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
```
Seems like we only need to deserialize in a PythonTorchTensorRTModule, and we should probably use `setup_engine` instead. The standard interface should be: provide the serialized engine, then `setup_engine` will set the module up properly for both Python and C++.
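A sketch of the contract being suggested (the constructor arguments are illustrative, not the exact signature; the point is that the caller hands over a serialized engine and `setup_engine()` owns deserialization and context creation):

```python
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

# Hypothetical wiring: serialized_engine, input_names, output_names, and
# settings are assumed to exist in the calling scope.
module = PythonTorchTensorRTModule(
    serialized_engine=serialized_engine,
    input_binding_names=input_names,
    output_binding_names=output_names,
    settings=settings,
)
module.setup_engine()  # deserializes the engine and creates the execution context
```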
```diff
@@ -532,7 +548,10 @@ def run(
     # self.engine_cache could be None if:
```
We are kind of avoiding this, but I feel like we might want to completely restructure `run` to assume refit and check if the user wants immutable weights. Also, we might want to pull the cache-pulling and cache-inserting code into helpers, just to make it easier to understand.
Lines 551-598 and 671-711 should probably become helpers that pull and insert the weight-stripped engine. We should have one for single-module refit as well.
Something like:

```python
if self.reuse_cached_engines:
    weight_stripped_serialized_engine = _pull_cached_engine(hash, settings, inputs)
    serialized_engine = fit_weights(self, weight_stripped_serialized_engine)


def fit_weights(self, weight_stripped_serialized_engine):
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(weight_stripped_serialized_engine)

    from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm

    _refit_single_trt_engine_with_gm(
        new_gm=self.module,
        old_engine=engine,
        input_list=self.input_specs,
        settings=self.compilation_settings,
        weight_name_map=self.weight_name_map,
    )

    # Serialize the refitted engine; the EXCLUDE_WEIGHTS flag must be cleared
    serialization_config = engine.create_serialization_config()
    serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    serialized_engine = engine.serialize_with_config(serialization_config)
    return serialized_engine
```
```diff
@@ -629,35 +657,68 @@ def run(
     assert serialized_engine

     _LOGGER.info(
-        f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
+        f"Build weight-stripped TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
```
Probably just leave this as is; if the user requests `immutable_weights`, it wouldn't apply.
```python
    weight_stripped_serialized_engine = serialized_engine
else:
    # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared
    runtime = trt.Runtime(TRT_LOGGER)
```
Why do we refit and then strip the weights again? If refit is enabled, shouldn't the builder always use a weight-stripped engine?
Make sure to add the refit settings to `_SETTINGS_TO_BE_ENGINE_INVARIANT`.
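For illustration, the kind of change being asked for (a sketch: `_SETTINGS_TO_BE_ENGINE_INVARIANT` is the existing tuple in torch_tensorrt's engine-cache code; the entries shown here, including the refit setting names, are assumptions and may not match the real tuple):

```python
# Settings that must match between compilations for a cached engine to be
# reused; the refit-related options need to participate in this check.
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
    "enabled_precisions",
    "max_aux_streams",
    "version_compatible",
    "hardware_compatible",
    "make_refittable",       # proposed refit settings from this PR
    "strip_engine_weights",
)
```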
Description
Fixes #3146
Type of change
Checklist: