Skip to content

[mlir][spirv] Implement UMod canonicalization for vector constants #141902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

fairywreath
Copy link
Contributor

@fairywreath fairywreath commented May 29, 2025

Closes #63174.

Implements this transformation pattern, which is currently only applied to scalars, for vectors:

%1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32
%2 = "spirv.UMod"(%1, %CONST_4) : (i32, i32) -> i32

to

%1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32
%2 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32

Additionally fixes and issue where patterns like this:

%1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32
%2 = "spirv.UMod"(%1, %CONST_32) : (i32, i32) -> i32

were incorrectly canonicalized to:

%1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32
%2 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32

which is incorrect since (X % A) % B == (X % B) IFF A is a multiple of B, i.e., B divides A.

@llvmbot
Copy link
Member

llvmbot commented May 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Darren Wihandi (fairywreath)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/141902.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+18-9)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+21-6)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e36d4b910193e..89b46577f061c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
 
 // The transformation is only applied if one divisor is a multiple of the other.
 
-// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
     if (!prevUMod)
       return failure();
 
-    IntegerAttr prevValue;
-    IntegerAttr currValue;
+    TypedAttr prevValue;
+    TypedAttr currValue;
     if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
         !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
       return failure();
 
-    APInt prevConstValue = prevValue.getValue();
-    APInt currConstValue = currValue.getValue();
+    // Ensure that previous divisor is a multiple of the current divisor. If
+    // not, fail the transformation.
+    bool isApplicable = false;
+    if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
+      auto currInt = dyn_cast<IntegerAttr>(currValue);
+      isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
+    } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
+      auto currVec = dyn_cast<DenseElementsAttr>(currValue);
+      isApplicable = llvm::all_of(
+          llvm::zip(prevVec.getValues<APInt>(), currVec.getValues<APInt>()),
+          [](auto pair) {
+            const auto &[a, b] = pair;
+            return a.urem(b) == 0;
+          });
+    }
 
-    // Ensure that one divisor is a multiple of the other. If not, fail the
-    // transformation.
-    if (prevConstValue.urem(currConstValue) != 0 &&
-        currConstValue.urem(prevConstValue) != 0)
+    if (!isApplicable)
       return failure();
 
     // The transformation is safe. Replace the existing UMod operation with a
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0fd6c18a6c241..52c915bfebc66 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
-// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-LABEL: @umod_vector_fold
 // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
-func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
   // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
   // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
   %const1 = spirv.Constant dense<32> : vector<4xi32>
   %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
-  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
   %const2 = spirv.Constant dense<4> : vector<4xi32>
   %1 = spirv.UMod %0, %const2 : vector<4xi32>
-  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
   // CHECK: return %[[UMOD0]], %[[UMOD1]]
   return %0, %1: vector<4xi32>, vector<4xi32>
 } 
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
-// CHECK-LABEL: @umod_fail_fold
+// CHECK-LABEL: @umod_fail_1_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
   // CHECK: %[[CONST5:.*]] = spirv.Constant 5
   // CHECK: %[[CONST32:.*]] = spirv.Constant 32
   %const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,21 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
+// CHECK-LABEL: @umod_fail_2_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+  // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+  %const1 = spirv.Constant dense<4> : vector<4xi32>
+  %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  %const2 = spirv.Constant dense<32> : vector<4xi32>
+  %1 = spirv.UMod %0, %const2 : vector<4xi32>
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//

@fairywreath fairywreath marked this pull request as draft May 29, 2025 06:49
@fairywreath fairywreath marked this pull request as ready for review May 30, 2025 05:41
@IgWod-IMG
Copy link
Contributor

Me again :) I have two minor comments:

  1. Could you please reference the issues you fix ([mlir][spirv] Add support for vector of integers in UModSimplication canonicalization pattern #63174) in the description. In fact, if you use one of those magic keywords GitHub should automatically close the issue, once this change is merged.

  2. It'd probably be useful to include the pattern that was broken in the tests (if not already there), so it can catch any future regression.

@fairywreath
Copy link
Contributor Author

@IgWod-IMG Thanks for the review. The vector test I added actually tested the pattern that was broken. Anyways I have added more tests, one which is a scalar version on the pattern that was broken.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving this

@kuhar kuhar requested a review from IgWod-IMG June 2, 2025 16:50
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM*

Thanks for addressing my previous comments!

*Subject to addressing other reviewer comments :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Add support for vector of integers in UModSimplication canonicalization pattern
4 participants