-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Implement UMod canonicalization for vector constants #141902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Darren Wihandi (fairywreath) Changes — Full diff: https://github.com/llvm/llvm-project/pull/141902.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e36d4b910193e..89b46577f061c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
// The transformation is only applied if one divisor is a multiple of the other.
-// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
using OpRewritePattern::OpRewritePattern;
@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
if (!prevUMod)
return failure();
- IntegerAttr prevValue;
- IntegerAttr currValue;
+ TypedAttr prevValue;
+ TypedAttr currValue;
if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
!matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
return failure();
- APInt prevConstValue = prevValue.getValue();
- APInt currConstValue = currValue.getValue();
+ // Ensure that previous divisor is a multiple of the current divisor. If
+ // not, fail the transformation.
+ bool isApplicable = false;
+ if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
+ auto currInt = dyn_cast<IntegerAttr>(currValue);
+ isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
+ } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
+ auto currVec = dyn_cast<DenseElementsAttr>(currValue);
+ isApplicable = llvm::all_of(
+ llvm::zip(prevVec.getValues<APInt>(), currVec.getValues<APInt>()),
+ [](auto pair) {
+ const auto &[a, b] = pair;
+ return a.urem(b) == 0;
+ });
+ }
- // Ensure that one divisor is a multiple of the other. If not, fail the
- // transformation.
- if (prevConstValue.urem(currConstValue) != 0 &&
- currConstValue.urem(prevConstValue) != 0)
+ if (!isApplicable)
return failure();
// The transformation is safe. Replace the existing UMod operation with a
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0fd6c18a6c241..52c915bfebc66 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
-// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-LABEL: @umod_vector_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
-func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
%const1 = spirv.Constant dense<32> : vector<4xi32>
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
- // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
%const2 = spirv.Constant dense<4> : vector<4xi32>
%1 = spirv.UMod %0, %const2 : vector<4xi32>
- // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: vector<4xi32>, vector<4xi32>
}
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
-// CHECK-LABEL: @umod_fail_fold
+// CHECK-LABEL: @umod_fail_1_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
// CHECK: %[[CONST5:.*]] = spirv.Constant 5
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
%const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,21 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
+// CHECK-LABEL: @umod_fail_2_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+ // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+ %const1 = spirv.Constant dense<4> : vector<4xi32>
+ %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+ // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+ %const2 = spirv.Constant dense<32> : vector<4xi32>
+ %1 = spirv.UMod %0, %const2 : vector<4xi32>
+ // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+ // CHECK: return %[[UMOD0]], %[[UMOD1]]
+ return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
// -----
//===----------------------------------------------------------------------===//
|
ff5cd52
to
891d5ab
Compare
Me again :) I have two minor comments:
|
@IgWod-IMG Thanks for the review. The vector test I added actually tested the pattern that was broken. Anyway, I have added more tests, one of which is a scalar version of the pattern that was broken. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for improving this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM*
Thanks for addressing my previous comments!
*Subject to addressing other reviewer comments :)
Closes #63174.
Implements this transformation pattern, which is currently only applied to scalars, for vectors:
to
Additionally fixes an issue where patterns like this:
were incorrectly canonicalized to:
which is incorrect since `(X % A) % B == X % B` holds if and only if A is a multiple of B, i.e., B divides A.