[MLIR] Minor fixes to FoldTransposeBroadcast rewrite (NFC) #140083

Merged
1 commit merged into llvm:main on May 19, 2025

@momchil-velikov
Collaborator

momchil-velikov commented May 15, 2025

This patch contains two minor changes, which I believe were the original author's intent.

  • when folding transpose(broadcast(x)) emit broadcast(x) instead of broadcast(broadcast(x)). The latter causes intermittent verifier failures, e.g.
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
  %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
  %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
  "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
  • when checking permutation groups, the variable low was set just once to zero, so the check was quadratic. It looks like the intent was for low to track the beginning of each dimension group. (Nevertheless, the check was correct.)
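
To illustrate the second change, here is a minimal standalone sketch of a group-locality check in which `low` advances to the start of each group. It is not the code in VectorOps.cpp: the function name `isPermutationLocalToGroups` and the `groupEnds` parameter (the exclusive end index of each consecutive group) are assumptions made for illustration, and the real pattern computes its group boundaries itself rather than taking them as input.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not the MLIR implementation). `groupEnds` lists the
// exclusive end index of each consecutive dimension group, in increasing order.
bool isPermutationLocalToGroups(const std::vector<int64_t> &permutation,
                                const std::vector<int64_t> &groupEnds) {
  int64_t low = 0; // start of the group currently being checked
  for (int64_t high : groupEnds) {
    assert(high >= low && high <= static_cast<int64_t>(permutation.size()));
    // Every index in [low, high) must be mapped to a destination in
    // [low, high); otherwise the permutation moves a dimension across groups.
    for (int64_t i = low; i < high; ++i)
      if (permutation[i] < low || permutation[i] >= high)
        return false;
    // Advance to the next group. Without this update each group would be
    // re-scanned starting from index 0.
    low = high;
  }
  return true;
}
```

With groups ending at 2, 3, 5 and 6, the permutation [1, 0, 2, 4, 3, 5] stays within its groups, while [2, 1, 0, 3, 4, 5] does not. Because `low` advances, each index is visited once, so the scan is linear in the rank.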

@llvmbot
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-mlir-vector

Author: Momchil Velikov (momchil-velikov)

Changes

This patch contains two minor changes, which I believe were the original author's intent.

  • when folding transpose(broadcast(x)) emit broadcast(x) instead of broadcast(broadcast(x)). The latter causes intermittent verifier failures, e.g.
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
  %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
  %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
  "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()
  • when checking permutation groups, the variable low was set just once to zero, so the check was quadratic. It looks like the intent was for low to track the beginning of each dimension group. (Nevertheless, the check was correct.)

Full diff: https://github.com/llvm/llvm-project/pull/140083.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+3-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..7ae43b64a5deb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6201,7 +6201,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
     bool inputIsScalar = !inputType;
     if (inputIsScalar) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                       transpose.getVector());
+                                                       broadcast.getSource());
       return success();
     }
 
@@ -6227,6 +6227,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
                 transpose, "permutation not local to group");
           }
         }
+        low = high;
       }
     }
 
@@ -6241,7 +6242,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
            "not broadcastable directly to transpose output");
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                     transpose.getVector());
+                                                     broadcast.getSource());
 
     return success();
   }

@joker-eph
Collaborator

Can you provide a test that triggered the verification error?

@momchil-velikov
Collaborator Author

> Can you provide a test that triggered the verification error?

Yes, it is from the test in the original PR #135096:
https://github.com/llvm/llvm-project/blob/540cf25a6df56fa1810a7411477dca9896aeed20/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

@newling
Contributor

newling commented May 15, 2025

Thanks @momchil-velikov, I think you are correct. What I don't understand is why there wasn't a verification error every time. For example in

func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {

It was converting

  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>

to

  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
  %t = vector.broadcast %bcast : vector<8x1xf32> to vector<1x8xf32>

but that %t is invalid. Unless the folder runs before verification? But yeah, it should have gone directly to

  %t = vector.broadcast %value : f32 to vector<1x8xf32>
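
As a rough model of why the intermediate op is invalid while the direct one is fine, the sketch below checks broadcast compatibility on plain shapes. It assumes numpy-like rules (source rank no larger than result rank, dimensions aligned from the trailing end, and each source dimension either 1 or equal to the result dimension) and ignores scalable dimensions. `Shape` and `isBroadcastableShape` are hypothetical names for illustration, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// An empty shape models a scalar such as f32.
using Shape = std::vector<int64_t>;

// Hypothetical, simplified model of the broadcast shape rule (assumption:
// numpy-like semantics, no scalable dimensions).
bool isBroadcastableShape(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size())
    return false; // a broadcast cannot drop dimensions
  // Source dimensions line up with the trailing dimensions of the result.
  std::size_t delta = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    // A source dimension may be stretched only if it is 1.
    if (src[i] != 1 && src[i] != dst[i + delta])
      return false;
  }
  return true;
}
```

Under this model, a scalar broadcasts to 1x8, but 8x1 does not broadcast to 1x8 (the leading 8 cannot become 1). Likewise 4x1x1x7 broadcasts to 3x2x4x5x6x7, whereas 2x3x4x5x6x7 does not, which matches the op that the printer rejects in the PR description.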

@newling
Contributor

newling commented May 15, 2025

> when checking permutation groups, the variable low was set just once to zero, so the check was quadratic. It looks like the intent was for low to track the beginning of each dimension group. (Nevertheless, the check was correct.)

Yes, good catch thanks.

Contributor

@newling newling left a comment

Thanks for the fix. I find it worrying that this didn't create a verification error every time (as per the comment above). I'll spend some time trying to figure out why not (any suggestions, @joker-eph?). If you can too, @momchil-velikov, that'll be useful (I'm not sure I'll manage to reproduce any verification error with my setup).

nit: technically not (NFC)?

@newling
Contributor

newling commented May 15, 2025

mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8>, sym_name = "broadcast_transpose_mixed_example"}> ({
^bb0(%arg0: vector<4x1x1x7xi8>):
  %0 = "vector.broadcast"(%arg0) : (vector<4x1x1x7xi8>) -> vector<2x3x4x5x6x7xi8>
  %1 = "vector.broadcast"(%0) : (vector<2x3x4x5x6x7xi8>) -> vector<3x2x4x5x6x7xi8>
  "func.return"(%1) : (vector<3x2x4x5x6x7xi8>) -> ()
}) : () -> ()

I can see this every time when running mlir-opt with the --debug flag.

@newling
Contributor

newling commented May 15, 2025

> The latter causes intermittent verifier failures, e.g.

I guess, what do you mean by intermittent? This fixes a bug (thanks again!); I'm just not sure it would cause anything observable, like an intermittent compilation failure. My rough understanding is that the verifier is only called with --debug because the printer verifies, but as the 2 broadcasts fold, by the time the pattern rewrite is done the IR is valid again.

@momchil-velikov
Collaborator Author

> > The latter causes intermittent verifier failures, e.g.

> I guess, what do you mean by intermittent?

I guess bad wording. I saw it with --debug happening at one point and disappearing before the end of the compilation.

@newling
Contributor

newling commented May 15, 2025

> > > The latter causes intermittent verifier failures, e.g.

> > I guess, what do you mean by intermittent?

> I guess bad wording. I saw it with --debug happening at one point and disappearing before the end of the compilation.

Ok, good. Thanks!

@banach-space
Contributor

Can you try building mlir-opt with -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=On? With expensive checks on, mlir-opt should fail.

@newling
Contributor

newling commented May 16, 2025

> Can you try building mlir-opt with -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=On? With expensive checks on, mlir-opt should fail.

I've just tried this and yes, it fails. That's a useful flag, thanks for pointing this out @banach-space! Should there be a build in CI with this on?

@banach-space
Contributor

> > Can you try building mlir-opt with -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=On? With expensive checks on, mlir-opt should fail.

> I've just tried this and yes, it fails. That's a useful flag, thanks for pointing this out @banach-space! Should there be a build in CI with this on?

All the builders are defined here:

Unfortunately, searching for MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS gives no hits. I'm hesitant to enable this in one of the existing Arm builders, since I worry it might be too fragile and break frequently.

I'll look into whether we can add a dedicated builder for this - but please don’t hold your breath; these things can take time.

@dcaballe
Contributor

Let's see if I follow... we initially generate:

  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
  %t = vector.broadcast %bcast : vector<8x1xf32> to vector<1x8xf32>

but the CHECK rules don't fail:

//  CHECK-SAME: (%[[ARG:.+]]: f32)
//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
//       CHECK:   return %[[V]] : vector<1x8xf32>

because there is another folding that is turning the two invalid vector.broadcast ops into a single but valid one, and the intermediate state is not verified without the expensive checks?

Yeah, having a buildbot covering this would be awesome.

@momchil-velikov momchil-velikov merged commit 38d2306 into llvm:main May 19, 2025
15 checks passed
@newling
Contributor

newling commented May 19, 2025

> > > Can you try building mlir-opt with -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=On? With expensive checks on, mlir-opt should fail.

> > I've just tried this and yes, it fails. That's a useful flag, thanks for pointing this out @banach-space! Should there be a build in CI with this on?

> All the builders are defined here:

> Unfortunately, searching for MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS gives no hits. I'm hesitant to enable this in one of the existing Arm builders, since I worry it might be too fragile and break frequently.

> I'll look into whether we can add a dedicated builder for this - but please don’t hold your breath; these things can take time.

Thanks @banach-space, that might be useful. Do let me know if I can help.
