Open
Labels: model support (Hub issue for progress on adding support for a specific model)
Description
Found this bug while fixing the slice-and-copy shape `*` (unknown rank) issue #1953. PR: #1970
Success: test_slicecopy.py
FAIL: test_maked_fill.py
Key PyTorch line: x_new[..., 0] = 1
In t5_model:
%144 = torch.aten.masked_fill_.Scalar %134, %143, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor
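The e2e test is named after masked_fill, and the T5 IR above uses `aten.masked_fill_.Scalar`. As a quick eager-mode sanity check (plain PyTorch, no torch-mlir involved), the sketch below shows that `x_new[..., 0] = 1` and a `masked_fill_` with a first-column mask produce the same tensor. The mask construction is an illustrative assumption, not how T5 actually builds its mask.

```python
import torch

x = torch.rand(1, 4)
x_new = x.new_zeros(x.shape)  # [1, 4] float32 zeros, as in the test

# Reference: the indexed in-place assignment from the failing test.
ref = x_new.clone()
ref[..., 0] = 1

# Illustrative rewrite (assumption): a boolean mask that is True only at
# index 0 of the last dim; its shape [4] broadcasts against [1, 4].
mask = torch.arange(x.shape[-1]) == 0
alt = x_new.clone()
alt.masked_fill_(mask, 1)  # eager counterpart of aten.masked_fill_.Scalar

print(alt)  # tensor([[1., 0., 0., 0.]])
```

Both paths write the scalar into the same single element; only the way the target location is described (index vs. mask) differs.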
E2E test:
import torch

from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case


class SliceCopyMaskedFillModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 4], torch.float32, True),
    ])
    def forward(self, x):
        x_new = x.new_zeros(x.shape) # tensor([[0, 0, 0, 0]])
        x_new[..., 0] = 1 # tensor([[1, 0, 0, 0]])
        return x_new


@register_test_case(module_factory=lambda: SliceCopyMaskedFillModule())
def SliceCopyMaskedFillModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 4))
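For reference, the expected result can be checked without torch-mlir. The sketch below reuses the test's forward without the `@export`/`@annotate_args` decorators (so only stock PyTorch is needed) and also runs it through `torch.jit.script`, since the pipeline in this issue starts from a TorchScript module. It only confirms the reference output; it does not exercise the TOSA backend.

```python
import torch


class SliceCopyMaskedFillModule(torch.nn.Module):
    # Same forward as the e2e test, minus the torch-mlir decorators.
    def forward(self, x):
        x_new = x.new_zeros(x.shape)  # zeros with x's shape, dtype and device
        x_new[..., 0] = 1             # in-place write to index 0 of the last dim
        return x_new


x = torch.rand(1, 4)
eager = SliceCopyMaskedFillModule()

# Script the module to make sure TorchScript preserves the eager result.
scripted = torch.jit.script(eager)

eager_out = eager(x)
scripted_out = scripted(x)
print(eager_out)  # tensor([[1., 0., 0., 0.]])
```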
Run the e2e test:
python -m e2e_testing.main -c tosa -f "SliceCopyMaskedFillModule" -v
Compiling SliceCopyMaskedFillModule_basic...
XFAIL - "SliceCopyMaskedFillModule_basic"
Summary:
Expectedly Failed: 1
TorchScript to Torch backend:
torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' /tmp/SliceCopyMaskedFillModule.mlir -mlir-print-ir-after-failure -mlir-disable-threading
/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: error: unsupported by backend contract: tensor with unknown rank
x_new = x.new_zeros(x.shape) # tensor([[0, 0, 0, 0]])
^
/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: note: see current operation: %9 = "torch.tensor_static_info_cast"(%8) : (!torch.vtensor<[1,4],f32>) -> !torch.vtensor<*,f32>
/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/slice_like.py:574:16: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "SliceCopyMaskedFillModule"} {
  func.func @forward(%arg0: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<*,f32> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %int6 = torch.constant.int 6
    %none = torch.constant.none
    %int-1 = torch.constant.int -1
    %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,4],f32>
    %2 = torch.tensor_static_info_cast %1 : !torch.vtensor<[1,4],f32> to !torch.vtensor<*,f32>
    %3 = torch.copy.to_tensor %2 : !torch.tensor<*,f32>
    %4 = torch.aten.slice.Tensor %3, %int-1, %int0, %int1, %int1 : !torch.tensor<*,f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],f32>
    %5 = torch.aten.squeeze.dim %4, %int-1 : !torch.tensor<[1,1],f32>, !torch.int -> !torch.tensor<[1],f32>
    %6 = torch.tensor_static_info_cast %5 : !torch.tensor<[1],f32> to !torch.tensor<*,f32>
    %7 = torch.copy.to_vtensor %6 : !torch.vtensor<*,f32>
    %8 = torch.prim.device %7 : !torch.vtensor<*,f32> -> !torch.Device
    %9 = torch.aten.tensor.int %int1, %int6, %8, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[],f32>
    %10 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
    %11 = torch.aten.broadcast_to %9, %10 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
    %12 = torch.tensor_static_info_cast %11 : !torch.vtensor<[1],f32> to !torch.vtensor<*,f32>
    torch.overwrite.tensor.contents %12 overwrites %6 : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
    %13 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
    return %13 : !torch.vtensor<*,f32>
  }
}
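To make the dump easier to follow, the sketch below replays its op sequence in eager PyTorch via `torch.ops.aten`. Mapping `torch.overwrite.tensor.contents` to `copy_` is an assumption for illustration (in torch-mlir it is an internal op on non-value tensors, not a PyTorch call), and the `torch.tensor_static_info_cast` / `torch.copy.*` ops are type-level bookkeeping with no eager equivalent. In eager mode every intermediate has a static shape ([1, 4], then [1, 1], then [1]); in the dump, the `*` ranks are introduced by the casts %2, %6 and %12.

```python
import torch

# %1 = aten.zeros([1, 4], dtype=6): ScalarType 6 is float32.
# (%2/%3 only erase the static shape and switch to a non-value tensor.)
base = torch.zeros([1, 4], dtype=torch.float32)

# %4 = aten.slice.Tensor(%3, dim=-1, start=0, end=1, step=1): a [1, 1] view.
sliced = torch.ops.aten.slice.Tensor(base, -1, 0, 1, 1)
# %5 = aten.squeeze.dim(%4, -1): a [1] view that still aliases `base`.
view = torch.ops.aten.squeeze.dim(sliced, -1)

# %9 = aten.tensor.int(1, dtype=6, device, False): a 0-d float32 tensor.
# %11 = aten.broadcast_to(%9, [1]).
value = torch.tensor(1, dtype=torch.float32, device=view.device).broadcast_to([1])

# torch.overwrite.tensor.contents %12 overwrites %6 (modelled here as copy_).
view.copy_(value)

# %13 = copy.to_vtensor %3: the function returns the mutated base tensor.
print(base)  # tensor([[1., 0., 0., 0.]])
```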
@ramiro050 any idea how to fix this?