✨[Feature] Remove requirement for require_full_compilation=False when using input_signature #1602

Closed
@gs-olive

Description

Feature Context

Models that are fully supported in TRT, except that their input type is a collection, should be fully compilable in Torch-TRT. Since Torch-executed list packing and unpacking code is already inserted (by necessity) even when a model is fully supported, there should be no need to disable full compilation when providing complex input types. Additionally, operators such as prim::ListUnpack should not be added to torch_executed_ops automatically when input_signature is used, as they currently are, since evaluators for them exist.

Desired Solution

The preferred solution is to remove the requirement for require_full_compilation=False when using input_signature, and to stop forcing collection-based operators to execute in fallback. Both restrictions are currently imposed by the following code:

elif compile_spec["input_signature"] is not None:
    log(
        Level.Warning,
        "Input signature parsing is an experimental feature, behavior and APIs may change",
    )
    signature = _parse_input_signature(compile_spec["input_signature"])
    info.input_signature = _C.InputSignature(signature)  # py_object

    if not compile_spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
        )

    log(
        Level.Debug,
        "Grouped inputs currently requires additional settings to enable the feature",
    )
    log(
        Level.Debug,
        """Adding the following ops to torch_executed_ops:
- aten::__getitem__
- prim::ListConstruct
- prim::ListUnpack
- prim::TupleIndex
- prim::TupleConstruct
- prim::TupleUnpack
""",
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "aten::__getitem__"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::ListConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleUnpack"
    )

This would require modification of the C++ core code as well, to ensure that relaxing this requirement will not cause further issues with the existing compilation phases.
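To make the Python-side difference concrete, here is a minimal sketch that models the compile spec as a plain dict and contrasts the current branch with a relaxed one. The function names `apply_input_signature_current` and `apply_input_signature_relaxed` are hypothetical stand-ins, not Torch-TRT APIs, and the sketch deliberately ignores the C++ core changes mentioned above; it only illustrates which checks and op insertions would be dropped.

```python
from copy import deepcopy

# Collection ops that the existing branch forces into Torch fallback
# (names copied from the snippet quoted in this issue).
COLLECTION_OPS = [
    "aten::__getitem__",
    "prim::ListConstruct",
    "prim::ListUnpack",
    "prim::TupleIndex",
    "prim::TupleConstruct",
    "prim::TupleUnpack",
]


def apply_input_signature_current(compile_spec):
    """Hypothetical stand-in for today's behavior: fallback must be
    enabled, and the collection ops are forced to run in Torch."""
    spec = deepcopy(compile_spec)
    fallback = spec["torch_fallback"]
    if not fallback["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled"
        )
    fallback["forced_fallback_ops"].extend(COLLECTION_OPS)
    return spec


def apply_input_signature_relaxed(compile_spec):
    """Hypothetical relaxed behavior proposed here: no fallback check and
    no automatic additions to the forced fallback ops."""
    return deepcopy(compile_spec)


# A spec requesting full compilation (fallback disabled) with a grouped input.
# The signature value is a placeholder; real code would pass Input objects.
full_spec = {
    "input_signature": (["input_a", "input_b"],),
    "torch_fallback": {"enabled": False, "forced_fallback_ops": []},
}

relaxed = apply_input_signature_relaxed(full_spec)
print(relaxed["torch_fallback"])
```

Under these assumptions, the current path rejects `full_spec` outright, while the relaxed path leaves the fallback settings untouched so that the existing evaluators can handle the collection ops.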

Additional Context

A proof-of-concept for this feature already exists in PR #1599, which could be used as a template to enable full-compilation functionality for collection inputs as well. This would complete the plan for Collection IO as discussed in #629 (comment).
