Torch-mlir fails with the following error when running the test case below:
error: failed to legalize operation 'torch.operator' that was explicitly marked illegal: %0 = "torch.operator"(%arg0, %arg1, %arg2) <{name = "torch.transformers.grouped_mm_fallback"}> : (!torch.vtensor<[4,1024],f32>, !torch.vtensor<[16,1024,512],f32>, !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32>
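Context: This is a custom op from the HuggingFace transformers MoE dispatch path. The rows of the input (%arg0, [4,1024]) are tokens sorted by expert; offs (%arg2, [16]) holds one entry per expert, where offs[i] is the row offset for expert i. Each row is multiplied by the weight matrix of its assigned expert (a [1024,512] slice of %arg1, [16,1024,512]).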
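For reference, the semantics described in the context can be sketched in plain Python. This is only an illustration, not the transformers implementation: the function name grouped_mm_reference is hypothetical, and it assumes offs holds cumulative end offsets (offs[i] is one past the last row belonging to expert i), which the description above does not state explicitly.

```python
def grouped_mm_reference(x, weights, offs):
    """Reference sketch of a grouped matmul over expert-sorted rows.

    x:       list of M rows, each of length K          (tokens sorted by expert)
    weights: list of E matrices, each K x N            (one per expert)
    offs:    list of E ints, ASSUMED to be cumulative end rows per expert
    Returns a list of M rows, each of length N.
    """
    out = []
    start = 0
    for expert, end in enumerate(offs):
        w = weights[expert]
        n_cols = len(w[0])
        # Rows [start, end) belong to this expert; an empty range is allowed.
        for row in x[start:end]:
            out.append(
                [sum(row[k] * w[k][j] for k in range(len(row))) for j in range(n_cols)]
            )
        start = end
    return out


# Tiny example: 3 tokens, 2 experts, K = N = 2.
x = [[1, 2], [3, 4], [5, 6]]
weights = [
    [[1, 0], [0, 1]],  # expert 0: identity
    [[2, 0], [0, 2]],  # expert 1: scale by 2
]
offs = [1, 3]  # expert 0 owns row 0, expert 1 owns rows 1..2
result = grouped_mm_reference(x, weights, offs)
```

With the shapes from the failing test, x would be 4 x 1024, weights 16 x 1024 x 512, and offs of length 16, so most experts would receive no rows.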
Command to reproduce:
torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{extra-library=})' test_grouped_mm.mlir
test_grouped_mm.mlir:
module {
  func.func @forward(%arg0: !torch.vtensor<[4,1024],f32>, %arg1: !torch.vtensor<[16,1024,512],f32>, %arg2: !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32> {
    %0 = "torch.operator"(%arg0, %arg1, %arg2) <{name = "torch.transformers.grouped_mm_fallback"}> : (!torch.vtensor<[4,1024],f32>, !torch.vtensor<[16,1024,512],f32>, !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32>
    return %0 : !torch.vtensor<[4,512],f32>
  }
}