
✨[Feature] Support for aten.expand where the new shape changes the rank #1948

Closed
@gs-olive

Description


Context

Currently, the implementation of torch.ops.aten.expand, which is based on acc_ops.expand, requires the shape being expanded to have the same rank as the input tensor (see below). This differs from Torch, whose expand also accepts target shapes of higher rank and adds the new dimensions at the front.

@tensorrt_converter(acc_ops.expand)
def acc_ops_expand_tensor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    shape = list(kwargs["sizes"])
    input_val = get_trt_tensor(network, input_t, f"{name}_input")
    if network.has_implicit_batch_dimension:
        shape = shape[1:]
    ranks = len(input_val.shape)
    # TRT does not support different dimension size
    assert len(shape) == ranks
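Once the ranks agree, one way to express an expand in TensorRT is a slice whose stride is 0 along every dimension that is being broadcast, so the single element is read repeatedly. The pure-Python sketch below computes such slice parameters under that assumption; `expand_slice_params` is a hypothetical helper, and it assumes all target sizes are already resolved (no -1 entries). It deliberately keeps the same equal-rank restriction as the converter above.

```python
from typing import List, Sequence, Tuple


def expand_slice_params(
    input_shape: Sequence[int], target_shape: Sequence[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: compute (start, shape, stride) for a stride-0 slice
    that expands `input_shape` to `target_shape`.

    Like the converter's assertion, this requires both shapes to have the same rank.
    """
    if len(input_shape) != len(target_shape):
        raise ValueError(
            f"rank mismatch: input rank {len(input_shape)}, "
            f"target rank {len(target_shape)}"
        )
    start = [0] * len(input_shape)
    stride = []
    for in_dim, out_dim in zip(input_shape, target_shape):
        if in_dim == out_dim:
            stride.append(1)  # dimension is copied unchanged
        elif in_dim == 1:
            stride.append(0)  # stride 0 repeats the single element
        else:
            raise ValueError(f"cannot expand size {in_dim} to {out_dim}")
    return start, list(target_shape), stride
```

For example, expanding (1, 64) to (16, 64) yields start [0, 0], shape [16, 64] and stride [0, 1], while the reported case of (64,) to (16, 1, 64) is rejected for the same reason the assertion fires.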

Valid Torch Behavior

import torch

x = torch.ones((64,))
x_new = x.expand((16, 1, 64))
print(x_new.shape)  # torch.Size([16, 1, 64])
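As a reference for what a converter would need to reproduce, here is a small pure-Python model of the output shape of expand, written from the documented torch.Tensor.expand rules: new dimensions are added at the front and need explicit sizes, -1 keeps an existing size, and only singleton dimensions can grow. `torch_expand_shape` is a hypothetical helper for illustration, not a PyTorch API.

```python
from typing import Sequence, Tuple


def torch_expand_shape(
    input_shape: Sequence[int], sizes: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical pure-Python model of the shape produced by Tensor.expand."""
    if len(sizes) < len(input_shape):
        raise ValueError("expand sizes must have at least the input's rank")
    num_new = len(sizes) - len(input_shape)
    result = []
    for i, size in enumerate(sizes):
        if i < num_new:
            # New leading dimension: -1 (or any negative size) is not allowed.
            if size < 0:
                raise ValueError(f"new dimension {i} needs an explicit size")
            result.append(size)
            continue
        in_dim = input_shape[i - num_new]
        if size == -1:
            result.append(in_dim)  # -1 keeps the existing size
        elif size < 0:
            raise ValueError(f"invalid size {size} at dim {i}")
        elif in_dim in (1, size):
            result.append(size)  # singleton grows, or size is unchanged
        else:
            raise ValueError(f"cannot expand size {in_dim} to {size} at dim {i}")
    return tuple(result)
```

With these rules, torch_expand_shape((64,), (16, 1, 64)) gives (16, 1, 64), matching the Torch snippet above.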

Feature Proposal

Extend acc_ops_expand_tensor to automatically pad the input tensor's shape with leading singleton dimensions, via existing broadcast/padding utilities, so that the ranks of the input tensor and the expand shape agree.
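A minimal sketch of the proposed rank padding, in pure Python and under stated assumptions: the input shape is padded with leading 1s up to the rank of the requested sizes (the padded shape is what a reshape, for example a TensorRT shuffle layer, could be given), and any -1 entries are then resolved against the padded shape. This takes the reshape route rather than IPaddingLayer, and both function names are hypothetical.

```python
from typing import Sequence, Tuple


def pad_shape_to_rank(shape: Sequence[int], rank: int) -> Tuple[int, ...]:
    """Hypothetical helper: prepend singleton dims so `shape` has `rank` dims."""
    if len(shape) > rank:
        raise ValueError(f"cannot pad a rank-{len(shape)} shape to rank {rank}")
    return (1,) * (rank - len(shape)) + tuple(shape)


def expand_after_padding(
    input_shape: Sequence[int], sizes: Sequence[int]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: return (padded_input_shape, target_shape).

    After padding, both shapes have the same rank, so the converter's
    equal-rank assertion would hold.
    """
    padded = pad_shape_to_rank(input_shape, len(sizes))
    num_new = len(sizes) - len(input_shape)
    # Torch does not allow -1 for newly added leading dimensions.
    if any(s == -1 for s in list(sizes)[:num_new]):
        raise ValueError("-1 is not allowed for new leading dimensions")
    # Resolve -1 entries ("keep this size") against the padded input shape.
    target = tuple(p if s == -1 else s for p, s in zip(padded, sizes))
    assert len(padded) == len(target)  # mirrors the converter's check
    return padded, target
```

For the reported case, the (64,) input is padded to (1, 1, 64) and the target stays (16, 1, 64), so the ranks agree and the existing equal-rank path can proceed.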

Consider using IPaddingLayer, as is done here:

Additional Context

Error encountered on a model:

  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/fx/converters/aten_ops_converters.py", line 378, in aten_ops_expand
    return acc_ops_converters.acc_ops_expand_tensor(
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/fx/converters/acc_ops_converters.py", line 2559, in acc_ops_expand_tensor
    assert len(shape) == ranks
AssertionError: While executing %expand : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [16, 1, 64]), kwargs = {_itensor_to_tensor_meta: {<tensorrt.tensorrt.ITensor object at 0x7f404a9e66b0>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {})}})
