
[TorchToLinalg] torch-mlir fails to legalize torch.transformers.grouped_mm_fallback #4618

Description

@prasadakp27

Torch-mlir fails with the following error when running the test case below:

error: failed to legalize operation 'torch.operator' that was explicitly marked illegal: %0 = "torch.operator"(%arg0, %arg1, %arg2) <{name = "torch.transformers.grouped_mm_fallback"}> : (!torch.vtensor<[4,1024],f32>, !torch.vtensor<[16,1024,512],f32>, !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32>

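Context: this is a custom op emitted by the HuggingFace transformers MoE (mixture-of-experts) dispatch. The rows of the input (%arg0, [4,1024]) are tokens sorted by expert, the weight tensor (%arg1, [16,1024,512]) holds one 1024x512 matrix per expert, and offs (%arg2, [16]) gives one row offset per expert: offs[i] is the row offset for expert i. Each row is multiplied by its assigned expert's weight matrix, producing the [4,512] result.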
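To pin down the semantics a lowering would need to reproduce, here is a minimal pure-Python reference sketch. The helper name `grouped_mm_reference` is hypothetical. The issue does not say whether offs[i] is where expert i's rows start or where they end, so the sketch supports both through an assumed `offs_are_ends` flag. It also assumes that rows not covered by any expert stay zero. Treat it as a statement of the assumed semantics, not as the transformers implementation.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def grouped_mm_reference(
    x: Matrix,
    w: Sequence[Matrix],
    offs: Sequence[int],
    offs_are_ends: bool = True,
) -> Matrix:
    """Hypothetical reference for a grouped matmul.

    x:    [M, K] rows (tokens), sorted so each expert's rows are contiguous.
    w:    [E, K, N] one K x N weight matrix per expert.
    offs: [E] row offsets. ASSUMPTION: exclusive end offsets when
          offs_are_ends is True, start offsets otherwise.
    Rows outside every expert's range are left as zeros (ASSUMPTION).
    """
    num_rows = len(x)
    num_experts = len(w)
    if len(offs) != num_experts:
        raise ValueError("offs must have exactly one entry per expert")
    n = len(w[0][0])
    out = [[0.0] * n for _ in range(num_rows)]
    for e in range(num_experts):
        # Half-open row range [start, end) owned by expert e.
        if offs_are_ends:
            start = offs[e - 1] if e > 0 else 0
            end = offs[e]
        else:
            start = offs[e]
            end = offs[e + 1] if e + 1 < num_experts else num_rows
        weight = w[e]
        for r in range(start, min(end, num_rows)):
            # out[r] = x[r] @ w[e]
            out[r] = [
                sum(x[r][k] * weight[k][j] for k in range(len(x[r])))
                for j in range(n)
            ]
    return out


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # M=3 tokens, K=2
    w = [
        [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
        [[2.0, 0.0], [0.0, 2.0]],  # expert 1: scale by 2
    ]
    # Expert 0 owns row 0, expert 1 owns rows 1 and 2.
    print(grouped_mm_reference(x, w, offs=[1, 3]))
    # -> [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

With start offsets the same split is written `offs=[0, 1]` and `offs_are_ends=False`. Whichever convention the real op uses, the loop makes clear that the result is a per-expert matmul over contiguous row blocks.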

Command to reproduce:

torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{extra-library=})' test_grouped_mm.mlir

test_grouped_mm.mlir:

module {
  func.func @forward(%arg0: !torch.vtensor<[4,1024],f32>, %arg1: !torch.vtensor<[16,1024,512],f32>, %arg2: !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32> {
    %0 = "torch.operator"(%arg0, %arg1, %arg2) <{name = "torch.transformers.grouped_mm_fallback"}> : (!torch.vtensor<[4,1024],f32>, !torch.vtensor<[16,1024,512],f32>, !torch.vtensor<[16],si64>) -> !torch.vtensor<[4,512],f32>
    return %0 : !torch.vtensor<[4,512],f32>
  }
}
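One possible way to express this op without a custom operator is to compute each row's expert index from offs, gather that expert's weight, and do a row-by-matrix product. In PyTorch terms this roughly corresponds to `torch.searchsorted(offs, rows, right=True)`, `index_select` on the [16,1024,512] weight along dim 0, and a batched matmul of [4,1,1024] by [4,1024,512]. This is offered as a decomposition idea, not as an existing torch-mlir pattern. The pure-Python sketch below assumes offs holds cumulative, exclusive end offsets. Both function names are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence

Matrix = List[List[float]]


def expert_ids_from_end_offsets(offs: Sequence[int], num_rows: int) -> List[int]:
    """Map each row to its expert, assuming expert e owns rows
    [offs[e-1], offs[e]) (ASSUMPTION: cumulative exclusive end offsets).

    Row r belongs to the first expert e with offs[e] > r, which is exactly
    bisect_right(offs, r). A result equal to len(offs) means "no expert".
    """
    return [bisect_right(offs, r) for r in range(num_rows)]


def grouped_mm_via_gather(x: Matrix, w: Sequence[Matrix], offs: Sequence[int]) -> Matrix:
    """Hypothetical decomposition: per-row expert lookup, gather of that
    expert's K x N weight, then a row-by-matrix product for every row."""
    n = len(w[0][0])
    out: Matrix = []
    for r, e in enumerate(expert_ids_from_end_offsets(offs, len(x))):
        if e >= len(w):
            # Row lies past the last offset; leave it as zeros (ASSUMPTION).
            out.append([0.0] * n)
            continue
        weight = w[e]  # gather: this row's expert weight
        out.append(
            [sum(x[r][k] * weight[k][j] for k in range(len(x[r]))) for j in range(n)]
        )
    return out
```

The trade-off is memory: gathering weights per row materializes an [M, K, N] tensor. For the test case that is 4 x 1024 x 512 = 2,097,152 f32 values (8 MiB), smaller than the 32 MiB weight tensor, but it grows linearly with the token count. A loop over experts avoids that cost.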
