I have the following PyTorch model:
```python
import torch
import torch_mlir

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x[0, :, :] = y
        return x.sum()

model = Model()
model.train(False)
inputs = [
    torch.rand((3, 224, 224), dtype=torch.float32),
    torch.rand((224, 224), dtype=torch.float32),
]
module = torch_mlir.compile(model, inputs, use_tracing=True, verbose=True)
print(module)
```
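For reference, judging from the IR dumps below, tracing seems to turn the slice assignment into a select, two full slices, and an in-place copy, roughly equivalent to this Python (my own reconstruction, not verified):

```python
# Roughly what x[0, :, :] = y becomes under tracing, as far as I can tell:
# select index 0 along dim 0, take full slices of both remaining dims,
# then copy y into the resulting view in place.
x.select(0, 0)[:, :].copy_(y)
```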
When I run this, it gives me the following error:
```
<unknown>:0: error: unsupported by backend contract: tensor with unknown rank
<unknown>:0: note: see current operation: %5 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[3,224,224],f32>) -> !torch.vtensor<*,f32>
<unknown>:0: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
```
This is the IR at the start:
```mlir
module attributes {torch.debug_module_name = "Model"} {
  func.func private @__torch__.Model.forward(%arg0: !torch.nn.Module<"__torch__.Model">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[3,224,224],f32>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[224,224],f32>}) -> !torch.tensor {
    %1 = torch.tensor_static_info_cast %arg1 : !torch.tensor to !torch.tensor<[3,224,224],f32>
    %2 = torch.tensor_static_info_cast %arg2 : !torch.tensor to !torch.tensor<[224,224],f32>
    %int0 = torch.constant.int 0
    %int0_0 = torch.constant.int 0
    %3 = torch.aten.select.int %1, %int0, %int0_0 : !torch.tensor<[3,224,224],f32>, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %int0_1 = torch.constant.int 0
    %int0_2 = torch.constant.int 0
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    %int1 = torch.constant.int 1
    %4 = torch.aten.slice.Tensor %3, %int0_1, %int0_2, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %int1_3 = torch.constant.int 1
    %int0_4 = torch.constant.int 0
    %int9223372036854775807_5 = torch.constant.int 9223372036854775807
    %int1_6 = torch.constant.int 1
    %5 = torch.aten.slice.Tensor %4, %int1_3, %int0_4, %int9223372036854775807_5, %int1_6 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %false_7 = torch.constant.bool false
    %6 = torch.aten.copy_ %5, %2, %false_7 : !torch.tensor<[224,224],f32>, !torch.tensor<[224,224],f32>, !torch.bool -> !torch.tensor<[224,224],f32>
    %none_8 = torch.constant.none
    %7 = torch.aten.sum %1, %none_8 : !torch.tensor<[3,224,224],f32>, !torch.none -> !torch.tensor<[],f32>
    %8 = torch.tensor_static_info_cast %7 : !torch.tensor<[],f32> to !torch.tensor
    return %8 : !torch.tensor
  }
  torch.class_type @__torch__.Model {
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.method "forward", @__torch__.Model.forward
  }
  %false = torch.constant.bool false
  %none = torch.constant.none
  %0 = torch.nn_module {
    torch.slot "training", %false : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
  } : !torch.nn.Module<"__torch__.Model">
}
```
This is the IR at LowerToBackendContract (which fails):
```mlir
module attributes {torch.debug_module_name = "Model"} {
  func.func @forward(%arg0: !torch.vtensor<[3,224,224],f32>, %arg1: !torch.vtensor<[224,224],f32>) -> !torch.vtensor<[],f32> {
    %int224 = torch.constant.int 224
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    %none = torch.constant.none
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,224,224],f32> to !torch.vtensor<*,f32>
    %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32>
    %2 = torch.tensor_static_info_cast %1 : !torch.tensor<*,f32> to !torch.tensor<[3,224,224],f32>
    %3 = torch.aten.slice.Tensor %2, %int0, %int0, %int1, %int1 : !torch.tensor<[3,224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,224,224],f32>
    %4 = torch.aten.squeeze.dim %3, %int0 : !torch.tensor<[1,224,224],f32>, !torch.int -> !torch.tensor<[224,224],f32>
    %5 = torch.aten.slice.Tensor %4, %int0, %int0, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %6 = torch.aten.slice.Tensor %5, %int1, %int0, %int9223372036854775807, %int1 : !torch.tensor<[224,224],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[224,224],f32>
    %7 = torch.prim.ListConstruct %int224, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.aten.broadcast_to %arg1, %7 : !torch.vtensor<[224,224],f32>, !torch.list<int> -> !torch.vtensor<[224,224],f32>
    torch.overwrite.tensor.contents %8 overwrites %6 : !torch.vtensor<[224,224],f32>, !torch.tensor<[224,224],f32>
    %9 = torch.copy.to_vtensor %2 : !torch.vtensor<[3,224,224],f32>
    %10 = torch.aten.sum %9, %none : !torch.vtensor<[3,224,224],f32>, !torch.none -> !torch.vtensor<[],f32>
    return %10 : !torch.vtensor<[],f32>
  }
}
```
After talking to @ramiro050, he explained that this happens because PyTorch decomposes this index-put operation into a series of slice and copy_ operations, and since the sum is then performed on the mutated tensor, value semantics cannot be achieved. He mentioned that #1901 is trying to fix this, but that mine is a more complicated case in which multiple slice operations need to be matched.
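As a stopgap on my end, rewriting the model without in-place mutation seems to avoid the problem, since there is then no in-place update for the value-semantics conversion to untangle. This is just a sketch I have not fully verified against torch-mlir:

```python
import torch

class FunctionalModel(torch.nn.Module):
    """Mutation-free rewrite of the model above (untested sketch)."""

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # Instead of writing y into x[0, :, :] in place, build a new tensor
        # from y (given a fresh leading dim) and the untouched rows of x.
        updated = torch.cat([y.unsqueeze(0), x[1:]], dim=0)
        return updated.sum()
```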
Thanks in advance for your help! 😄