
[mlir][Affine] Add nsw to lowering of AffineMulExpr. #121535


Merged
MaheshRavishankar merged 1 commit into llvm:main from fold_div_of_mul on Jan 6, 2025

Conversation

MaheshRavishankar
Contributor

Since index operations have no set bitwidth, it is ill-defined to use signed/unsigned wrapping behavior. The corollary is that it is always safe to add nsw/nuw to the lowering of affine ops.

Also add a folder to fold div(s|u)i (mul (a, v), v) -> a.
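As a minimal sketch of the new fold (assuming the folder lands as in the patch below; the function and value names here are purely illustrative), running mlir-opt -canonicalize over

func.func @fold_divsi_of_muli(%a: index, %v: index) -> index {
  %0 = arith.muli %a, %v overflow<nsw> : index
  %1 = arith.divsi %0, %v : index
  return %1 : index
}

is expected to drop the division and return %a directly. The unsigned variant behaves the same way when the multiply carries nuw and the division is arith.divui.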

@llvmbot
Member

llvmbot commented Jan 3, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

Since index operations have no set bitwidth, it is ill-defined to use signed/unsigned wrapping behavior. The corollary is that it is always safe to add nsw/nuw to the lowering of affine ops.

Also add a folder to fold div(s|u)i (mul (a, v), v) -> a.


Patch is 44.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121535.diff

11 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+9-5)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+31-7)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir (+4-4)
  • (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+48-48)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+14-14)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+62)
  • (modified) mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir (+2-2)
  • (modified) mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir (+2-2)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-4)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+5-5)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+4-4)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d3ead20fb5cd3..8328b6430e182b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -56,7 +56,9 @@ class AffineApplyExpander
     auto rhs = visit(expr.getRHS());
     if (!lhs || !rhs)
       return nullptr;
-    auto op = builder.create<OpTy>(loc, lhs, rhs);
+    auto op = builder.create<OpTy>(loc, lhs, rhs,
+                                   arith::IntegerOverflowFlags::nsw |
+                                       arith::IntegerOverflowFlags::nuw);
     return op.getResult();
   }
 
@@ -93,8 +95,9 @@ class AffineApplyExpander
     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
     Value isRemainderNegative = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
-    Value correctedRemainder =
-        builder.create<arith::AddIOp>(loc, remainder, rhs);
+    Value correctedRemainder = builder.create<arith::AddIOp>(
+        loc, remainder, rhs,
+        arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw);
     Value result = builder.create<arith::SelectOp>(
         loc, isRemainderNegative, correctedRemainder, remainder);
     return result;
@@ -178,8 +181,9 @@ class AffineApplyExpander
     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
     Value negatedQuotient =
         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
-    Value incrementedQuotient =
-        builder.create<arith::AddIOp>(loc, quotient, oneCst);
+    Value incrementedQuotient = builder.create<arith::AddIOp>(
+        loc, quotient, oneCst,
+        arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw);
     Value result = builder.create<arith::SelectOp>(
         loc, nonPositive, negatedQuotient, incrementedQuotient);
     return result;
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..c43f182f40947b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -596,6 +596,18 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
                                                  return a.udiv(b);
                                                });
 
+  // divui (muli (a, v), v) -> a
+  if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+    if (muliOp.hasNoUnsignedWrap()) {
+      if (getRhs() == muliOp.getRhs()) {
+        return muliOp.getLhs();
+      }
+      if (getRhs() == muliOp.getLhs()) {
+        return muliOp.getRhs();
+      }
+    }
+  }
+
   return div0 ? Attribute() : result;
 }
 
@@ -632,6 +644,18 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
         return a.sdiv_ov(b, overflowOrDiv0);
       });
 
+  // divsi (muli (a, v), v) -> a
+  if (auto muliOp = getLhs().getDefiningOp<arith::MulIOp>()) {
+    if (muliOp.hasNoSignedWrap()) {
+      if (getRhs() == muliOp.getRhs()) {
+        return muliOp.getLhs();
+      }
+      if (getRhs() == muliOp.getLhs()) {
+        return muliOp.getRhs();
+      }
+    }
+  }
+
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -2341,12 +2365,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
 
   // Constant-fold constant operands over non-splat constant condition.
   // select %cst_vec, %cst0, %cst1 => %cst2
-  if (auto cond =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
-    if (auto lhs =
-            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
-      if (auto rhs =
-              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+  if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+          adaptor.getCondition())) {
+    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getTrueValue())) {
+      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+              adaptor.getFalseValue())) {
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2614,7 +2638,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
-   case AtomicRMWKind::maxnumf:
+  case AtomicRMWKind::maxnumf:
     return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minnumf:
     return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
index 58580a194df0c7..e9915ea550a5d8 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -9,7 +9,7 @@ func.func @affine_vector_load(%arg0 : index) {
 // CHECK:       %[[buf:.*]] = memref.alloc
 // CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
   return
 }
@@ -26,10 +26,10 @@ func.func @affine_vector_store(%arg0 : index) {
 // CHECK:       %[[buf:.*]] = memref.alloc
 // CHECK:       %[[val:.*]] = arith.constant dense
 // CHECK:       %[[c_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = arith.muli %arg0, %[[c_1]] : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[a:.*]] = arith.muli %arg0, %[[c_1]] overflow<nsw, nuw> : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %{{.*}}, %[[a]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  %[[c:.*]] = arith.addi %[[b]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
   return
 }
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 00d7b6b8d65f67..e21d20995a7d01 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -156,9 +156,9 @@ func.func private @get_idx() -> (index)
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     call @body(%[[v0:.*]]) : (index) -> ()
@@ -177,9 +177,9 @@ func.func @if_only() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     call @body(%[[v0:.*]]) : (index) -> ()
@@ -202,14 +202,14 @@ func.func @if_else() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     %[[c0_0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:     %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:     %[[v4:.*]] = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:     %[[v4:.*]] = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[v5:.*]] = arith.cmpi sge, %[[v4]], %[[c0_0]] : index
 // CHECK-NEXT:     if %[[v5]] {
 // CHECK-NEXT:       call @body(%[[v0:.*]]) : (index) -> ()
@@ -217,7 +217,7 @@ func.func @if_else() {
 // CHECK-NEXT:   } else {
 // CHECK-NEXT:     %[[c0_0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:     %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:     %{{.*}} = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:     %{{.*}} = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %{{.*}} = arith.cmpi sge, %{{.*}}, %[[c0_0]] : index
 // CHECK-NEXT:     if %{{.*}} {
 // CHECK-NEXT:       call @mid(%[[v0:.*]]) : (index) -> ()
@@ -245,7 +245,7 @@ func.func @nested_ifs() {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.addi %[[v0]], %[[cm10]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.addi %[[v0]], %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v2:.*]] = arith.cmpi sge, %[[v1]], %[[c0]] : index
 // CHECK-NEXT:   %[[v3:.*]] = scf.if %[[v2]] -> (i64) {
 // CHECK-NEXT:     scf.yield %[[c0_i64]] : i64
@@ -272,25 +272,25 @@ func.func @if_with_yield() -> (i64) {
 // CHECK-NEXT:   %[[v0:.*]] = call @get_idx() : () -> index
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %{{.*}} : index
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] : index
+// CHECK-NEXT:   %[[v3:.*]] = arith.addi %[[v2]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v4:.*]] = arith.cmpi sge, %[[v3]], %[[c0]] : index
 // CHECK-NEXT:   %[[cm1_0:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v5:.*]] = arith.addi %{{.*}}, %[[cm1_0]] : index
+// CHECK-NEXT:   %[[v5:.*]] = arith.addi %{{.*}}, %[[cm1_0]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v6:.*]] = arith.cmpi sge, %[[v5]], %[[c0]] : index
 // CHECK-NEXT:   %[[v7:.*]] = arith.andi %[[v4]], %[[v6]] : i1
 // CHECK-NEXT:   %[[cm1_1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v8:.*]] = arith.addi %{{.*}}, %[[cm1_1]] : index
+// CHECK-NEXT:   %[[v8:.*]] = arith.addi %{{.*}}, %[[cm1_1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v9:.*]] = arith.cmpi sge, %[[v8]], %[[c0]] : index
 // CHECK-NEXT:   %[[v10:.*]] = arith.andi %[[v7]], %[[v9]] : i1
 // CHECK-NEXT:   %[[cm1_2:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v11:.*]] = arith.addi %{{.*}}, %[[cm1_2]] : index
+// CHECK-NEXT:   %[[v11:.*]] = arith.addi %{{.*}}, %[[cm1_2]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v12:.*]] = arith.cmpi sge, %[[v11]], %[[c0]] : index
 // CHECK-NEXT:   %[[v13:.*]] = arith.andi %[[v10]], %[[v12]] : i1
 // CHECK-NEXT:   %[[cm42:.*]] = arith.constant -42 : index
-// CHECK-NEXT:   %[[v14:.*]] = arith.addi %{{.*}}, %[[cm42]] : index
+// CHECK-NEXT:   %[[v14:.*]] = arith.addi %{{.*}}, %[[cm42]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v15:.*]] = arith.cmpi eq, %[[v14]], %[[c0]] : index
 // CHECK-NEXT:   %[[v16:.*]] = arith.andi %[[v13]], %[[v15]] : i1
 // CHECK-NEXT:   if %[[v16]] {
@@ -316,9 +316,9 @@ func.func @if_for() {
   %i = call @get_idx() : () -> (index)
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT:   %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] : index
+// CHECK-NEXT:   %[[v1:.*]] = arith.muli %[[v0]], %[[cm1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[c20:.*]] = arith.constant 20 : index
-// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] : index
+// CHECK-NEXT:   %[[v2:.*]] = arith.addi %[[v1]], %[[c20]] overflow<nsw, nuw> : index
 // CHECK-NEXT:   %[[v3:.*]] = arith.cmpi sge, %[[v2]], %[[c0]] : index
 // CHECK-NEXT:   if %[[v3]] {
 // CHECK-NEXT:     %[[c0:.*]]{{.*}} = arith.constant 0 : index
@@ -327,7 +327,7 @@ func.func @if_for() {
 // CHECK-NEXT:     for %{{.*}} = %[[c0:.*]]{{.*}} to %[[c42:.*]]{{.*}} step %[[c1:.*]]{{.*}} {
 // CHECK-NEXT:       %[[c0_:.*]]{{.*}} = arith.constant 0 : index
 // CHECK-NEXT:       %[[cm10:.*]] = arith.constant -10 : index
-// CHECK-NEXT:       %[[v4:.*]] = arith.addi %{{.*}}, %[[cm10]] : index
+// CHECK-NEXT:       %[[v4:.*]] = arith.addi %{{.*}}, %[[cm10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:       %[[v5:.*]] = arith.cmpi sge, %[[v4]], %[[c0_:.*]]{{.*}} : index
 // CHECK-NEXT:       if %[[v5]] {
 // CHECK-NEXT:         call @body2(%[[v0]], %{{.*}}) : (index, index) -> ()
@@ -371,11 +371,11 @@ func.func @if_for() {
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
 // CHECK-NEXT:     %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] overflow<nsw, nuw> : index
+// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
 // CHECK-NEXT:     %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] overflow<nsw, nuw> : index
 // CHECK-NEXT:     %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
 // CHECK-NEXT:     %[[c1_0:.*]] = arith.constant 1 : index
 // CHECK-NEXT:     for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
@@ -442,29 +442,29 @@ func.func @affine_applies(%arg0 : index) {
   %102 = arith.constant 102 : index
   %copy = affine.apply #map2(%zero)
 
-// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] : index
+// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] : index
+// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] overflow<nsw, nuw> : index
   %one = affine.apply #map3(%symbZero)[%zero]
 
 // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] : index
-// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] : index
-// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
+// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] : index
-// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
+// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] : index
-// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
+// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] : index
-// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
+// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
-// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
+// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw, nuw> : index
+// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] overflow<nsw, nuw> : index
   %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
   return
 }
@@ -498,7 +498,7 @@ func.func @affine_apply_mod(%arg0 : index) -> (index) {
 // CHECK-NEXT: %[[v0:.*]] = arith.remsi %{{.*}}, %[[c42]] : index
 // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[v1:.*]] = arith.cmpi slt, %[[v0]], %[[c0]] : index
-// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[c42]] : index
+// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %[[c42]] overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[v3:.*]] = arith.select %[[v1]], %[[v2]], %[[v0]] : index
   %0 = affine.apply #map_mod (%arg0)
   return %0 : index
@@ -509,7 +509,7 @@ func.func @affine_apply_mod_dynamic_divisor(%arg0 : index, %arg1 : index) -> (in
 // CHECK-NEXT: %[[v0:.*]] = arith.remsi %{{.*}}, %arg1 : index
 // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
 // CHECK-NEXT: %[[v1:.*]] = arith.cmpi slt, %[[v0]], %[[c0]] : index
-// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %arg1 : index
+// CHECK-NEXT: %[[v2:.*]] = arith.addi %[[v0]], %arg1 overflow<nsw, nuw> : index
 // CHECK-NEXT: %[[v3:.*]] = arith.select %[[v1]], %[[v2]], %[[v0]] : index
   %0 = affine.apply #map_mod_dynamic_divisor (%arg0)[%arg1]
   return %0 : index
@@ -567,7 +567,7 @@ func.func @affine_apply_ceildiv(%arg0 : index) -> (index) {
 // CHECK-NEXT:  %[[v3:.*]] = arith.select %[[v0]], %[[v1]], %[[v2]] : index
 // CHECK-NEXT:  %[[v4:.*]] = arith.divsi %[[v3]], %[[c42]] : index
 // CHECK-NEXT:  %[[v5:.*]] = arith.subi %[[c0]], %[[v4]] : index
-// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] : index
+// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[v7:.*]] = arith.select %[[v0]], %[[v5]], %[[v6]] : index
   %0 = affine.apply #map_ceildiv (%arg0)
   return %0 : index
@@ -583,7 +583,7 @@ func.func @affine_apply_ceildiv_dynamic_divisor(%arg0 : index, %arg1 : index) ->
 // CHECK-NEXT:  %[[v3:.*]] = arith.select %[[v0]], %[[v1]], %[[v2]] : index
 // CHECK-NEXT:  %[[v4:.*]] = arith.divsi %[[v3]], %arg1 : index
 // CHECK-NEXT:  %[[v5:.*]] = arith.subi %[[c0]], %[[v4]] : index
-// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] : index
+// CHECK-NEXT:  %[[v6:.*]] = arith.addi %[[v4]], %[[c1]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %[[v7:.*]] = arith.select %[[v0]], %[[v5]], %[[v6]] : index
   %0 = affine.apply #map_ceildiv_dynamic_divisor (%arg0)[%arg1]
   return %0 : index
@@ -597,7 +597,7 @@ func.func @affine_load(%arg0 : index) {
   }
 // CHECK:       %[[a:.*]] = arith.addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = arith.constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[b:.*]] = arith.addi %[[a]], %[[c7]] overflow<nsw, nuw> : index
 // CHECK-NEXT:  %{{.*}} = memref.load %[[v0:.*]][%[[b]]] : memref<10xf32>
   return
 }
@@ -610,10 +61...
[truncated]

Contributor

@krzysz00 krzysz00 left a comment


Sadly, I don't think this change is correct as-is.

That being said, I think that the broader point about overflows has something to it. I claim that all the additions and multiplications in an affine map (and so also in AffineExpandIndexOps.cpp and the other utilities) are nsw - assuming we document this - because the result of any operation is guaranteed not to wrap index in a signed way (for whatever the eventual width of index is)

Now, nsw && [proof of non-negativity] => nuw, but I don't think that can be done at this stage of codegen.

@@ -93,8 +95,9 @@ class AffineApplyExpander
     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
     Value isRemainderNegative = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
-    Value correctedRemainder =
-        builder.create<arith::AddIOp>(loc, remainder, rhs);
+    Value correctedRemainder = builder.create<arith::AddIOp>(
Contributor


I'm pretty sure this addition is not nuw.

Specifically, consider lhs = -5 and rhs = 3. Then, remainder == -2, which means that we're doing correctedRemainder = add nuw nsw -2, 3, which is not nuw
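Spelling that out with the constants from the comment (the snippet is illustrative only and elides the surrounding cmpi/select):

func.func @mod_correction_not_nuw() -> index {
  %lhs = arith.constant -5 : index
  %rhs = arith.constant 3 : index
  %rem = arith.remsi %lhs, %rhs : index       // remainder = -2
  %corrected = arith.addi %rem, %rhs : index  // corrected remainder = 1
  return %corrected : index
}

Viewed as unsigned values, %rem is 2^w - 2 (for whatever width w index ends up with), so adding 3 wraps past 2^w; tagging that arith.addi with nuw would therefore be wrong.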

@@ -178,8 +181,9 @@ class AffineApplyExpander
     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
     Value negatedQuotient =
         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
-    Value incrementedQuotient =
-        builder.create<arith::AddIOp>(loc, quotient, oneCst);
+    Value incrementedQuotient = builder.create<arith::AddIOp>(
Contributor


I don't have the detailed proof on me, but I suspect that this is going to fail on quotient = -1
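Whether quotient = -1 is actually reachable in this lowering is the open question above, but the reason it would break nuw is mechanical (constants are illustrative only):

func.func @ceildiv_increment_not_nuw() -> index {
  %quotient = arith.constant -1 : index
  %c1 = arith.constant 1 : index
  %incremented = arith.addi %quotient, %c1 : index  // 0
  return %incremented : index
}

As an unsigned value, -1 is all ones (2^w - 1), so adding 1 wraps to 0, which nuw forbids.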

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
@MaheshRavishankar MaheshRavishankar force-pushed the fold_div_of_mul branch 2 times, most recently from 7b6a4e1 to 13f06e4, on January 4, 2025 at 01:08
@MaheshRavishankar
Contributor Author

Sadly, I don't think this change is correct as-is.

That being said, I think that the broader point about overflows has something to it. I claim that all the additions and multiplications in an affine map (and so also in AffineExpandIndexOps.cpp and the other utilities) are nsw - assuming we document this - because the result of any operation is guaranteed not to wrap index in a signed way (for whatever the eventual width of index is)

Now, nsw && [proof of non-negativity] => nuw, but I don't think that can be done at this stage of codegen.

Fair point... I am actually staging these changes a bit... for now I am just going to add nsw to muli operations and come back and add nsw to addi operations.
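A minimal sketch of what that narrower change means (pass name and expected IR inferred from the updated tests in this patch; the map and function names are illustrative):

#map = affine_map<(d0) -> (d0 * 3)>
func.func @scale(%i: index) -> index {
  %0 = affine.apply #map(%i)
  return %0 : index
}

Run through mlir-opt -lower-affine, this is expected to lower to

%c3 = arith.constant 3 : index
%0 = arith.muli %i, %c3 overflow<nsw> : index

with nsw on the multiply only, while the arith.addi ops produced for other affine expressions stay unflagged for now.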

@MaheshRavishankar MaheshRavishankar changed the title from "[mlir][Affine] Add nuw/nsw to lowering of affine ops." to "[mlir][Affine] Add nsw to lowering of AffineMulExpr." on Jan 4, 2025
Since index operations have no set bitwidth, it is ill-defined to use
signed/unsigned wrapping behavior. The corollary to which is that it
is always safe to add nsw/nuw to lowering of affine ops.

Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 4, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 4, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
Contributor

@krzysz00 krzysz00 left a comment


Approved

I suspect we can get add in a followup PR, as well as all the arith.addi that get incidentally produced by the lowering

@MaheshRavishankar
Contributor Author

Approved

I suspect we can get add in a followup PR, as well as all the arith.addi that get incidentally produced by the lowering

Yes, I'll do that as a follow-up... tracking some hairy issues with this downstream (unrelated to this change, just some issue downstream). So staging this a bit.

@MaheshRavishankar MaheshRavishankar merged commit 8cd94e0 into llvm:main Jan 6, 2025
8 checks passed
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Jan 6, 2025
Carries 4 reverts

Related to Nanobind issues

llvm/llvm-project@5cd4274
llvm/llvm-project@08e2c15
llvm/llvm-project@b56d1ec

Related to RISC-V compilation

llvm/llvm-project@169c32e

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Jan 7, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Jan 7, 2025
Signed-off-by: MaheshRavishankar <mravisha@amd.com>