Skip to content

✨[Feature] Support torch.ops.aten.repeat.default converter #3172

Open
@dgcnz

Description

@dgcnz

Tasks

  • torch_tensorrt.dynamo.conversion.impl.slice.ops.repeat
    • Reference (the existing `expand` implementation, on which `repeat` can be modeled):
      def expand(
          ctx: ConversionContext,
          target: Target,
          source_ir: Optional[SourceIR],
          name: str,
          input_t: TRTTensor,
          shape: Shape,
      ) -> TRTTensor:
          """Broadcast ``input_t`` to ``shape`` using a TensorRT slice layer.

          Broadcasting is expressed as a slice whose stride is 0 along every
          expanded dimension and 1 along every dimension that is kept as-is.

          Args:
              ctx: Conversion context; ``ctx.net`` is the network being built.
              target: Target being converted; forwarded to helpers and used
                  for layer naming.
              source_ir: Source IR of the node; forwarded to helpers and used
                  for layer naming.
              name: Base name for the layers created here.
              input_t: Tensor to expand.
              shape: Target shape. Entries may be ints, ``DYNAMIC_DIM`` (keep
                  the input's size at that dimension) or ``TRTTensor`` values
                  resolved at runtime.

          Returns:
              The output tensor of the slice layer.

          Raises:
              RuntimeError: If ``shape`` has fewer dimensions than ``input_t``.
          """
          shape_rank = len(shape)
          initial_tensor_rank = len(input_t.shape)

          if initial_tensor_rank < shape_rank:
              # Lower-rank input: left-pad with size-1 dims so it can broadcast
              # against the target shape.
              input_t = prepend_ones(
                  ctx.net,
                  input_t,
                  name + "_expand_broadcast",
                  shape_rank - initial_tensor_rank,
              )
          elif initial_tensor_rank > shape_rank:
              # Report the input tensor's own rank, not the target shape's.
              raise RuntimeError(
                  f"expand called with {shape_rank}-dimensional shape on Tensor with {initial_tensor_rank} dimensions. "
                  "Cannot expand to shape with rank smaller than original tensor."
              )

          # After the padding above, the shape and tensor rank must be equal.
          assert len(input_t.shape) == shape_rank

          start = tuple([0] * shape_rank)

          # stride[dim] == 0 broadcasts that dimension; stride 1 keeps it.
          stride = []
          for i, o in zip(tuple(input_t.shape), shape):
              if (
                  isinstance(i, int)
                  and i != DYNAMIC_DIM
                  and isinstance(o, int)
                  and o != DYNAMIC_DIM
              ):
                  # Both sizes static: broadcast iff they differ.
                  stride.append(int(i == o))
              elif isinstance(i, int) and i != DYNAMIC_DIM and isinstance(o, TRTTensor):
                  # Static input dim with a runtime target dim: treat as broadcast.
                  stride.append(0)
              else:
                  # No broadcasting: output keeps the input's size here.
                  stride.append(1)

          # DYNAMIC_DIM entries are not broadcast; they take the input's size,
          # which is queried from the (possibly padded) input tensor.
          target_shape = []
          for i, dim in enumerate(shape):
              if dim == DYNAMIC_DIM:
                  target_shape.append(
                      get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
                  )
              else:
                  target_shape.append(dim)

          if any(isinstance(dim, TRTTensor) for dim in target_shape):
              # Runtime shape: start/shape/stride must be fed to the slice layer
              # as tensors (inputs 1..3) instead of static Dims.
              target_shape_tensor = cat(
                  ctx,
                  target,
                  source_ir,
                  name + "_shape_concat",
                  target_shape,
                  0,
                  cast_dtype=trt.int32,
              )
              start_tensor = cat(
                  ctx,
                  target,
                  source_ir,
                  name + "_start_concat",
                  start,
                  0,
                  cast_dtype=trt.int32,
              )
              stride_tensor = cat(
                  ctx,
                  target,
                  source_ir,
                  name + "_stride_concat",
                  stride,
                  0,
                  cast_dtype=trt.int32,
              )
              layer = ctx.net.add_slice(
                  input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
              )
              layer.set_input(1, start_tensor)
              layer.set_input(2, target_shape_tensor)
              layer.set_input(3, stride_tensor)
          else:
              layer = ctx.net.add_slice(
                  input_t, start=start, shape=target_shape, stride=stride
              )
          set_layer_name(layer, target, name, source_ir)
          return layer.get_output(0)
  • torch_tensorrt.dynamo.conversion.aten_ops_converters.aten_ops_repeat (reference: the existing `aten_ops_expand` converter below)
    @dynamo_tensorrt_converter(torch.ops.aten.expand.default, supports_dynamic_shapes=True)
    @enforce_tensor_types({0: (TRTTensor,)})
    def aten_ops_expand(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        """Convert ``aten.expand.default`` by delegating to ``impl.slice.expand``.

        ``args[0]`` is the tensor to expand and ``args[1]`` is the target shape;
        ``kwargs`` is not read.
        """
        input_tensor = args[0]
        target_shape = args[1]
        return impl.slice.expand(
            ctx, target, SourceIR.ATEN, name, input_tensor, target_shape
        )
  • Unit tests, modeled on the existing `expand` converter tests

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions