Skip to content

[mlir] Add affine.delinearize_index and affine.linearize_index ValueBoundsOpInterfaceImpl #118829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Dec 5, 2024

Adds a ValueBoundsOpInterface implementation for affine.linearize_index and affine.delinearize index. The implementations are effectively special cases of the affine.apply op implementation. An affine expression for the given result is formed from the constrained expressions of the op operands, and the result is constrained to be equal to this affine expression.

…d affine.linearize_index

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Adds a ValueBoundsOpInterface implementation for affine.linearize_index and affine.delinearize index. The implementations are effectively special cases of the affine.apply op implementation. An affine expression for the given result is formed from the constrained expressions of the op operands, and the result is constrained to be equal to this affine expression.


Full diff: https://github.com/llvm/llvm-project/pull/118829.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+65)
  • (modified) mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir (+37)
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 82a9fb0d490882..77107fb894bb0a 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -49,6 +49,67 @@ struct AffineApplyOpInterface
   }
 };
 
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto delinearizeOp = cast<AffineDelinearizeIndexOp>(op);
+    auto result = cast<OpResult>(value);
+    int64_t resultIdx = result.getResultNumber();
+    assert(result.getOwner() == delinearizeOp && "invalid value");
+
+    AffineExpr linearIdxExpr = cstr.getExpr(delinearizeOp.getLinearIndex());
+    SmallVector<OpFoldResult> basis = delinearizeOp.getMixedBasis();
+    SmallVector<AffineExpr> basisExprs;
+    AffineExpr modExpr = getAffineConstantExpr(1, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    for (int i = basis.size() - 1; i >= resultIdx; --i) {
+      AffineExpr basisExpr = cstr.getExpr(basis[i]);
+      modExpr = modExpr * basisExpr;
+      if (i > resultIdx)
+        strideExpr = strideExpr * basisExpr;
+    }
+    AffineExpr bound = linearIdxExpr;
+    if (resultIdx > 0)
+      bound = bound % modExpr;
+    if (resultIdx < delinearizeOp->getNumResults())
+      bound = bound.floorDiv(strideExpr);
+
+    cstr.bound(value) == bound;
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto linearizeOp = cast<AffineLinearizeIndexOp>(op);
+    assert(value == linearizeOp.getResult() && "invalid value");
+
+    SmallVector<OpFoldResult> basis = linearizeOp.getMixedBasis();
+    SmallVector<AffineExpr> basisExprs = llvm::map_to_vector(
+        basis, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+    basisExprs.push_back(getAffineConstantExpr(1, op->getContext()));
+
+    SmallVector<OpFoldResult> indices(linearizeOp.getMultiIndex());
+    SmallVector<AffineExpr> indexExprs = llvm::map_to_vector(
+        indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+
+    AffineExpr linearIdxExpr = getAffineConstantExpr(0, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    std::reverse(indexExprs.begin(), indexExprs.end());
+    std::reverse(basisExprs.begin(), basisExprs.end());
+    for (size_t i = 0; i < indexExprs.size(); ++i) {
+      strideExpr = strideExpr * basisExprs[i];
+      linearIdxExpr = linearIdxExpr + indexExprs[i] * strideExpr;
+    }
+
+    cstr.bound(value) == linearIdxExpr;
+  }
+};
+
 struct AffineMinOpInterface
     : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
                                                    AffineMinOp> {
@@ -98,6 +159,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+    AffineDelinearizeIndexOp::attachInterface<
+        AffineDelinearizeIndexOpInterface>(*ctx);
+    AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
+        *ctx);
     AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
     AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 935c08aceff548..2184d7fa5074e8 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -155,3 +155,40 @@ func.func @compare_maps(%a: index, %b: index) {
       : (index, index, index, index) -> ()
   return
 }
+
+// -----
+
+func.func @compare_affine_linearize_index(%a: index, %b: index) {
+  %0 = affine.linearize_index [%a, %b] by (2, 4) : index
+  %1 = affine.linearize_index [%a, %b] by (4) : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>}
+      : (index, index, index) -> ()
+  // expected-remark @below{{true}}
+  "test.compare"(%1, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>}
+      : (index, index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 4)>
+
+// CHECK-LABEL: func @affine_delinearize_index(
+//  CHECK-SAME:   %[[a:.*]]: index
+func.func @affine_delinearize_index(%a: index) -> (index, index, index, index) {
+  %0:2 = affine.delinearize_index %a into (2, 4) : index, index
+  %1:2 = affine.delinearize_index %a into (4) : index, index
+
+  // CHECK: %[[BOUND0:.+]] = affine.apply #[[$MAP]]()[%[[a]]]
+  %2 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND1:.+]] = affine.apply #[[$MAP1]]()[%[[a]]]
+  %3 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND2:.+]] = affine.apply #[[$MAP]]()[%[[a]]]
+  %4 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND3:.+]] = affine.apply #[[$MAP1]]()[%[[a]]]
+  %5 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+  // CHECK: return %[[BOUND0]], %[[BOUND1]], %[[BOUND2]], %[[BOUND3]]
+  return %2, %3, %4, %5: index, index, index, index
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-mlir-affine

Author: None (Max191)

Changes

Adds a ValueBoundsOpInterface implementation for affine.linearize_index and affine.delinearize index. The implementations are effectively special cases of the affine.apply op implementation. An affine expression for the given result is formed from the constrained expressions of the op operands, and the result is constrained to be equal to this affine expression.


Full diff: https://github.com/llvm/llvm-project/pull/118829.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+65)
  • (modified) mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir (+37)
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 82a9fb0d490882..77107fb894bb0a 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -49,6 +49,67 @@ struct AffineApplyOpInterface
   }
 };
 
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto delinearizeOp = cast<AffineDelinearizeIndexOp>(op);
+    auto result = cast<OpResult>(value);
+    int64_t resultIdx = result.getResultNumber();
+    assert(result.getOwner() == delinearizeOp && "invalid value");
+
+    AffineExpr linearIdxExpr = cstr.getExpr(delinearizeOp.getLinearIndex());
+    SmallVector<OpFoldResult> basis = delinearizeOp.getMixedBasis();
+    SmallVector<AffineExpr> basisExprs;
+    AffineExpr modExpr = getAffineConstantExpr(1, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    for (int i = basis.size() - 1; i >= resultIdx; --i) {
+      AffineExpr basisExpr = cstr.getExpr(basis[i]);
+      modExpr = modExpr * basisExpr;
+      if (i > resultIdx)
+        strideExpr = strideExpr * basisExpr;
+    }
+    AffineExpr bound = linearIdxExpr;
+    if (resultIdx > 0)
+      bound = bound % modExpr;
+    if (resultIdx < delinearizeOp->getNumResults())
+      bound = bound.floorDiv(strideExpr);
+
+    cstr.bound(value) == bound;
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto linearizeOp = cast<AffineLinearizeIndexOp>(op);
+    assert(value == linearizeOp.getResult() && "invalid value");
+
+    SmallVector<OpFoldResult> basis = linearizeOp.getMixedBasis();
+    SmallVector<AffineExpr> basisExprs = llvm::map_to_vector(
+        basis, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+    basisExprs.push_back(getAffineConstantExpr(1, op->getContext()));
+
+    SmallVector<OpFoldResult> indices(linearizeOp.getMultiIndex());
+    SmallVector<AffineExpr> indexExprs = llvm::map_to_vector(
+        indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
+
+    AffineExpr linearIdxExpr = getAffineConstantExpr(0, op->getContext());
+    AffineExpr strideExpr = getAffineConstantExpr(1, op->getContext());
+    std::reverse(indexExprs.begin(), indexExprs.end());
+    std::reverse(basisExprs.begin(), basisExprs.end());
+    for (size_t i = 0; i < indexExprs.size(); ++i) {
+      strideExpr = strideExpr * basisExprs[i];
+      linearIdxExpr = linearIdxExpr + indexExprs[i] * strideExpr;
+    }
+
+    cstr.bound(value) == linearIdxExpr;
+  }
+};
+
 struct AffineMinOpInterface
     : public ValueBoundsOpInterface::ExternalModel<AffineMinOpInterface,
                                                    AffineMinOp> {
@@ -98,6 +159,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
     AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+    AffineDelinearizeIndexOp::attachInterface<
+        AffineDelinearizeIndexOpInterface>(*ctx);
+    AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
+        *ctx);
     AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
     AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 935c08aceff548..2184d7fa5074e8 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -155,3 +155,40 @@ func.func @compare_maps(%a: index, %b: index) {
       : (index, index, index, index) -> ()
   return
 }
+
+// -----
+
+func.func @compare_affine_linearize_index(%a: index, %b: index) {
+  %0 = affine.linearize_index [%a, %b] by (2, 4) : index
+  %1 = affine.linearize_index [%a, %b] by (4) : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>}
+      : (index, index, index) -> ()
+  // expected-remark @below{{true}}
+  "test.compare"(%1, %a, %b) {rhs_map = affine_map<(a, b) -> (a * 4 + b)>}
+      : (index, index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 4)>
+
+// CHECK-LABEL: func @affine_delinearize_index(
+//  CHECK-SAME:   %[[a:.*]]: index
+func.func @affine_delinearize_index(%a: index) -> (index, index, index, index) {
+  %0:2 = affine.delinearize_index %a into (2, 4) : index, index
+  %1:2 = affine.delinearize_index %a into (4) : index, index
+
+  // CHECK: %[[BOUND0:.+]] = affine.apply #[[$MAP]]()[%[[a]]]
+  %2 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND1:.+]] = affine.apply #[[$MAP1]]()[%[[a]]]
+  %3 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND2:.+]] = affine.apply #[[$MAP]]()[%[[a]]]
+  %4 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+  // CHECK: %[[BOUND3:.+]] = affine.apply #[[$MAP1]]()[%[[a]]]
+  %5 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+  // CHECK: return %[[BOUND0]], %[[BOUND1]], %[[BOUND2]], %[[BOUND3]]
+  return %2, %3, %4, %5: index, index, index, index
+}

// CHECK: %[[BOUND2:.+]] = affine.apply #[[$MAP]]()[%[[a]]]
%4 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
// CHECK: %[[BOUND3:.+]] = affine.apply #[[$MAP1]]()[%[[a]]]
%5 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is testing the bounds of %0 twice, did you mean %1 here?

if (i > resultIdx)
strideExpr = strideExpr * basisExpr;
}
AffineExpr bound = linearIdxExpr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: A comment describing the expression would be nice


SmallVector<OpFoldResult> indices(linearizeOp.getMultiIndex());
SmallVector<AffineExpr> indexExprs = llvm::map_to_vector(
indices, [&](OpFoldResult ofr) { return cstr.getExpr(ofr); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of creating the vector and then reversing, you can do llvm::map_to_vector(llvm::reverse(indices), ... here and above.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, seems fine to me, aside from the note about the test

for (int i = basis.size() - 1; i >= resultIdx; --i) {
AffineExpr basisExpr = cstr.getExpr(basis[i]);
modExpr = modExpr * basisExpr;
if (i > resultIdx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like an equivalent way to phrase this is to walk up to the index before resultIdx, take the product of those, and then go one step further to get the final mod.

However, I'm not going to insist too much on that

@krzysz00
Copy link
Contributor

krzysz00 commented Jan 6, 2025

... looks like I wrote up an independent implementation of this in #121833 - can we check if these are equivalent?

@Max191 Max191 closed this Jan 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants