
LowerToBackendContract fails due to index put style operation #1925

Open
@ataheridezfouli-groq

I have the following PyTorch model:

import torch
import torch_mlir

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x : torch.Tensor, y : torch.Tensor):
        x[0, :, :] = y
        return x.sum()

model = Model()
model.train(False)
inputs = [
    torch.rand((3, 224, 224), dtype=torch.float32),
    torch.rand((224, 224), dtype=torch.float32),
]
module = torch_mlir.compile(model, inputs, use_tracing=True, verbose=True)
print(module)

When I run this, it gives me the following error:

<unknown>:0: error: unsupported by backend contract: tensor with unknown rank
<unknown>:0: note: see current operation: %5 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[3,224,224],f32>) -> !torch.vtensor<*,f32>
<unknown>:0: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py

This is the IR at the start:

module attributes {torch.debug_module_name = "Model"} {
  func.func private @__torch__.Model.forward(%arg0: !torch.nn.Module<"__torch__.Model">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[3,224,224],f32>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[224,224],f32>}) -> !torch.tensor {
    %1 = torch.tensor_static_info_cast %arg1 : !torch.tensor to !torch.tensor<[3,224,224],f32>
    %2 = torch.tensor_static_info_cast %arg2 : !torch.tensor to !torch.tensor<[224,224],f32>
    %int0 = torch.constant.int 0
    %int0_0 = torch.constant.int 0
    %3 = torch.aten.select.int %1, %int0, %int0_0 : !torch.tensor<[3,224,224],f32>, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %int0_1 = torch.constant.int 0
    %int0_2 = torch.constant.int 0
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    %int1 = torch.constant.int 1
    %4 = torch.aten.slice.Tensor %3, %int0_1, %int0_2, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %int1_3 = torch.constant.int 1
    %int0_4 = torch.constant.int 0
    %int9223372036854775807_5 = torch.constant.int 9223372036854775807
    %int1_6 = torch.constant.int 1
    %5 = torch.aten.slice.Tensor %4, %int1_3, %int0_4, %int9223372036854775807_5, %int1_6 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %false_7 = torch.constant.bool false
    %6 = torch.aten.copy_ %5, %2, %false_7 : !torch.tensor<[224,224],f32>, !torch.tensor<[224,224],f32>, !torch.bool -> !torch.tensor<[224,224],f32>
    %none_8 = torch.constant.none
    %7 = torch.aten.sum %1, %none_8 : !torch.tensor<[3,224,224],f32>, !torch.none -> !torch.tensor<[],f32>
    %8 = torch.tensor_static_info_cast %7 : !torch.tensor<[],f32> to !torch.tensor
    return %8 : !torch.tensor
  }
  torch.class_type @__torch__.Model {
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.method "forward", @__torch__.Model.forward
  }
  %false = torch.constant.bool false
  %none = torch.constant.none
  %0 = torch.nn_module {
    torch.slot "training", %false : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
  } : !torch.nn.Module<"__torch__.Model">
}

This is the IR at LowerToBackendContract (which fails):

module attributes {torch.debug_module_name = "Model"} {
  func.func @forward(%arg0: !torch.vtensor<[3,224,224],f32>, %arg1: !torch.vtensor<[224,224],f32>) -> !torch.vtensor<[],f32> {
    %int224 = torch.constant.int 224
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    %none = torch.constant.none
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,224,224],f32> to !torch.vtensor<*,f32>
    %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32>
    %2 = torch.tensor_static_info_cast %1 : !torch.tensor<*,f32> to !torch.tensor<[3,224,224],f32>
    %3 = torch.aten.slice.Tensor %2, %int0, %int0, %int1, %int1 : !torch.tensor<[3,224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,224,224],f32>
    %4 = torch.aten.squeeze.dim %3, %int0 : !torch.tensor<[1,224,224],f32>, !torch.int -> !torch.tensor<[224,224],f32>
    %5 = torch.aten.slice.Tensor %4, %int0, %int0, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %6 = torch.aten.slice.Tensor %5, %int1, %int0, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %7 = torch.prim.ListConstruct %int224, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.aten.broadcast_to %arg1, %7 : !torch.vtensor<[224,224],f32>, !torch.list<int> -> !torch.vtensor<[224,224],f32>
    torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor<[224,224],f32>, !torch.tensor<[224,224],f32>
    %9 = torch.copy.to_vtensor %2 : !torch.vtensor<[3,224,224],f32>
    %10 = torch.aten.sum %9, %none : !torch.vtensor<[3,224,224],f32>, !torch.none -> !torch.vtensor<[],f32>
    return %10 : !torch.vtensor<[],f32>
  }
}

I talked to @ramiro050, and he mentioned that this happens because PyTorch breaks the index-put operation into a sequence of slice + copy operations; since the sum is then performed on the mutated tensor, value semantics cannot be achieved. He mentioned that #1901 is trying to fix this, but that this is a more complicated case in which multiple slice operations need to be matched.
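For reference, here is a rough Python sketch of the select/slice/copy_ sequence that tracing records for x[0, :, :] = y, mirroring the ops in the initial IR above. The torch.ops.aten calls are used only to make the traced ops visible and are not how one would normally write this:

import torch

x = torch.rand(3, 224, 224)
y = torch.rand(224, 224)

# torch.aten.select.int: index dim 0 at position 0, yielding a [224,224] view of x
view = torch.ops.aten.select(x, 0, 0)
# torch.aten.slice.Tensor twice, one full-range slice per remaining dim
# (end = 9223372036854775807, i.e. INT64_MAX, is the "to the end" sentinel seen in the IR)
view = torch.ops.aten.slice(view, 0, 0, 9223372036854775807, 1)
view = torch.ops.aten.slice(view, 1, 0, 9223372036854775807, 1)
# torch.aten.copy_: the in-place write through the view mutates x itself,
# which is what blocks converting x to value semantics before the sum
view.copy_(y)

assert torch.equal(x[0], y)

A possible workaround until the multi-slice case is handled would be to write the update out-of-place so that no mutation reaches the sum, e.g. x = torch.cat([y.unsqueeze(0), x[1:]], dim=0); this is only a sketch and assumes the extra copy is acceptable for the model.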

Thanks in advance for your help! 😄
