
Commit 2ea0724

Merge pull request llvm#33 from clang-ykt/fix-gemm
Fix Gemm translation to ONNX dialect.
2 parents b50fc1f + 514cbcb commit 2ea0724

5 files changed: +11 -11 lines changed

src/dialect/onnx/onnx.td

Lines changed: 5 additions & 5 deletions
@@ -90,17 +90,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.

-def ONNXFullGemmOp: ONNX_Op<"FullGemm",
+def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "ONNX general matrix multiply operation";
+  let summary = "ONNX general matrix multiply operation without bias.";
   let description = [{

-    The "onnx.gemm" generic matrix multiplication with bias.
+    The "onnx.Gemm" generic matrix multiplication without bias.

  }];

-  let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }

 def ONNXConv1Op:ONNX_Op<"Conv1",
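The renamed op drops the bias operand entirely, so anything that previously built a three-input FullGemm now builds a two-input GemmNoBias. As a hypothetical C++ sketch only (not code from this commit), constructing the new op through the generated builder might look like the following; the helper name, the unranked f32 result type, and the assumption that `namespace mlir` and the project's generated op headers are in scope are all illustrative choices:

```cpp
// Hypothetical usage sketch, not part of this commit.
// GemmNoBias takes only the two matrix operands; the old FullGemm also took a bias.
Value buildGemmNoBias(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
  // Leave the result unranked here; ONNXGemmNoBiasOp::inferShapes() can refine
  // it later during the shape-inference pass.
  auto resultType = UnrankedTensorType::get(builder.getF32Type());
  return builder.create<ONNXGemmNoBiasOp>(loc, resultType, lhs, rhs);
}
```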

src/dialect/onnx/onnx_ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -347,9 +347,9 @@ void ONNXGemmOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }

-// FullGemm
+// GemmNoBias

-void ONNXFullGemmOp::inferShapes() {
+void ONNXGemmNoBiasOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
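Only the guard at the top of the new ONNXGemmNoBiasOp::inferShapes() is visible in this hunk. As a hedged reconstruction, assuming the body simply mirrors ONNXGemmOp::inferShapes() above with the bias operand removed, the full method would look roughly like this:

```cpp
// Hedged reconstruction, not the exact code from this commit: shape inference
// for GemmNoBias, patterned after ONNXGemmOp::inferShapes() minus the bias.
void ONNXGemmNoBiasOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  // An (M x K) * (K x N) matrix product yields an M x N result.
  SmallVector<int64_t, 2> dims;
  dims.emplace_back(lhsTy.getShape()[0]);
  dims.emplace_back(rhsTy.getShape()[1]);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
```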

src/pass/onnx_combine.td

Lines changed: 2 additions & 2 deletions
@@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 // Pattern-Match and Rewrite
 //===----------------------------------------------------------------------===//

-// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
+// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
 def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
-                                 (ONNXFullGemmOp $m1, $m2, $m3),
+                                 (ONNXGemmOp $m1, $m2, $m3),
                                  [(HasOneUse $res)]>;

 // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
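The declarative rule above is the entire change to this file; purely to spell out what it expresses, a hand-written C++ pattern performing the same match and rewrite might look like the sketch below. It assumes a recent upstream MLIR RewritePattern API (LogicalResult, Value::getDefiningOp<...>()), which may differ from the MLIR revision this commit was built against, and is not generated or shipped code:

```cpp
// Hedged illustration only: roughly what MulAddToGemmOptPattern expresses.
struct MulAddToGemmPattern : public OpRewritePattern<ONNXAddOp> {
  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXAddOp addOp,
                                PatternRewriter &rewriter) const override {
    // Match add(matmul(X, Y), Z) where the matmul result has a single use.
    auto matmulOp = addOp.getOperand(0).getDefiningOp<ONNXMatMulOp>();
    if (!matmulOp || !matmulOp.getResult().hasOneUse())
      return failure();
    // Rewrite to a single Gemm(X, Y, Z).
    rewriter.replaceOpWithNewOp<ONNXGemmOp>(
        addOp, addOp.getType(), matmulOp.getOperand(0), matmulOp.getOperand(1),
        addOp.getOperand(1));
    return success();
  }
};
```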

src/pass/shape_inference_pass.cpp

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         op->getName().getStringRef() != "onnx.Identity" &&
         op->getName().getStringRef() != "onnx.MatMul" &&
         op->getName().getStringRef() != "onnx.Gemm" &&
-        op->getName().getStringRef() != "onnx.FullGemm" &&
+        op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose")
       return false;

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
   // CHECK-LABEL: test_matmul_add_simplification
-  // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   "std.return"(%1) : (tensor<10x10xf32>) -> ()
