-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][linalg] Add pure tensor check for winogradConv2DHelper
#142299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This PR adds pure tensor semantics check for `winogradConv2DHelper` to prevent a crash.
@llvm/pr-subscribers-mlir-linalg Author: Longsheng Mou (CoTinker) ChangesThis PR adds pure tensor semantics check for Full diff: https://github.com/llvm/llvm-project/pull/142299.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6ebd3a53d981..e4221d4748415 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -904,6 +904,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
int64_t m, int64_t r) {
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
+
Value input = convOp.getInputs()[0];
Value filter = convOp.getInputs()[1];
Value output = convOp.getOutputs()[0];
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index c10e0ccebfd7c..1de861e653005 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -61,6 +61,22 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d_unsupported_type(%arg0: memref<2x10x10x5xf32>, %arg1: memref<2x3x3x5xf32>, %arg2: memref<2x8x8x2xf32>) {
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : memref<2x10x10x5xf32>, memref<2x3x3x5xf32>) outs(%arg2 : memref<2x8x8x2xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+1 {{apply Winograd Conv2D failed}}
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
return %0 : tensor<2x?x?x2xf32>
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR adds pure tensor semantics check for Full diff: https://github.com/llvm/llvm-project/pull/142299.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6ebd3a53d981..e4221d4748415 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -904,6 +904,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
int64_t m, int64_t r) {
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
+
Value input = convOp.getInputs()[0];
Value filter = convOp.getInputs()[1];
Value output = convOp.getOutputs()[0];
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index c10e0ccebfd7c..1de861e653005 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -61,6 +61,22 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d_unsupported_type(%arg0: memref<2x10x10x5xf32>, %arg1: memref<2x3x3x5xf32>, %arg2: memref<2x8x8x2xf32>) {
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : memref<2x10x10x5xf32>, memref<2x3x3x5xf32>) outs(%arg2 : memref<2x8x8x2xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+1 {{apply Winograd Conv2D failed}}
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
return %0 : tensor<2x?x?x2xf32>
|
This PR adds pure tensor semantics check for
winogradConv2DHelper
to prevent a crash. Fixes #141566.