Skip to content

[mlir][scf]: Add value bound between scf for loop yield and result #123200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 19, 2025

Conversation

AviadCo
Copy link
Contributor

@AviadCo AviadCo commented Jan 16, 2025

We can prove that:
%result == %init_arg + trip_count * (%yielded_value - %iter_arg). Where trip_count is (ub - lb) / step.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: Aviad Cohen (AviadCo)

Changes

We can prove that:
%result == %init_arg + trip_count * (%yielded_value - %iter_arg). Where trip_count is (ub - lb) / step.


Full diff: https://github.com/llvm/llvm-project/pull/123200.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+16)
  • (modified) mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (+67)
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index fbd236b648cb8a..8a27bf186d1c2a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -70,6 +70,22 @@ struct ForOpInterface
         cstr.bound(value) == cstr.getExpr(initArg);
       }
     }
+
+    if (dim.has_value() || isa<BlockArgument>(value))
+      return;
+
+    // `value` is result of `forOp`, we can prove that:
+    // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
+    // Where trip_count is (ub - lb) / step.
+    AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
+    AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
+    AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
+    AffineExpr tripCountExpr =
+        AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
+    AffineExpr oneIterAdvanceExpr =
+        cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
+    cstr.bound(value) ==
+        cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
   }
 
   void populateBoundsForIndexValue(Operation *op, Value value,
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 6e0c16a9a2b33f..c636fb7ae072ff 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -267,3 +267,70 @@ func.func @compare_scf_for(%a: index, %b: index, %c: index) {
   }
   return
 }
+
+// -----
+
+func.func @scf_for_result_infer() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %c0) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %c10) {cmp = "LE"} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_for_result_infer_dynamic_init(%i : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %i) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  %6 = arith.addi %i, %c10 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c2 iter_args(%arg = %i) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  %6 = arith.addi %i, %c5 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
+  return
+}

We can prove that:
%result == %init_arg + trip_count * (%yielded_value - %iter_arg).
Where trip_count is (ub - lb) / step.
@AviadCo AviadCo merged commit f8b2794 into llvm:main Jan 19, 2025
8 checks passed
@AviadCo AviadCo deleted the scf/forOpBounds branch January 19, 2025 06:52
peterbell10 pushed a commit to triton-lang/triton that referenced this pull request Jan 29, 2025
makslevental added a commit to makslevental/triton that referenced this pull request Feb 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants