
[mlir][sparse] recognize ReLu operation during sparsification #92016

Merged — 2 commits merged into llvm:main on May 13, 2024

Conversation

@aartbik (Contributor) commented on May 13, 2024

This is a proof-of-concept recognition of the most basic forms of ReLu operations, used to showcase sparsification of end-to-end PyTorch models. In the long run, we must avoid lowering such constructs too early, which creates the present need to raise them back.

See discussion at
https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918
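
For reference, the concrete shape being recognized is the compare-and-select form that ReLu(x) = max(x, 0) typically lowers to. A minimal illustration of the floating-point variant, taken from the test added in this PR (`%in` is the element value, `%cst` a zero constant of type f64):

```mlir
%2 = arith.cmpf ugt, %in, %cst : f64
%3 = arith.select %2, %in, %cst : f64
```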

@llvmbot (Member) commented on May 13, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/92016.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+90-10)
  • (added) mlir/test/Dialect/SparseTensor/sparse_relu.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 7f9820df984b2..b8d278152dc05 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -144,6 +144,7 @@ enum class TensorExp::Kind {
   kExpm1C,
   kLog1pF,
   kLog1pC,
+  kRelu,
   kSinF,
   kSinC,
   kTanhF,
@@ -316,7 +317,7 @@ class Merger {
   /// lattice point on an expression E is simply copied over, but with OP E
   /// as new expression. Returns the identifier of the new set.
   LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
-                  Operation *op = nullptr);
+                  Operation *op = nullptr, Attribute attr = nullptr);
 
   /// Maps the binary operator to the same operation but with one of its operand
   /// set to zero, i.e. each lattice point on an expression E is simply copied
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 308fbd965259d..0258f797143cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
 
 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                      Operation *o, Attribute a)
-    : kind(k), val(v), op(o) {
+    : kind(k), val(v), op(o), attr(a) {
   switch (kind) {
   // Leaf.
   case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
-    attr = a;
     children.e0 = x;
     children.e1 = y;
     return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
   const LatSetId sNew = conjSet(e, s0, s1, op);
   TensorExp::Kind kind = exp(e).kind;
-
   // Followed by all in s0.
   latSets[sNew].append(latSets[s0]);
   // Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                           bool includeLeft, TensorExp::Kind ltrans,
                           Operation *opleft, bool includeRight,
                           TensorExp::Kind rtrans, Operation *opright) {
+  Attribute a = exp(e).attr;
   const LatSetId sNew = conjSet(e, s0, s1, orig);
   // Left Region.
   if (includeLeft) {
     if (opleft)
-      s0 = mapSet(ltrans, s0, Value(), opleft);
+      s0 = mapSet(ltrans, s0, Value(), opleft, a);
     latSets[sNew].append(latSets[s0]);
   }
   // Right Region.
   if (includeRight) {
     if (opright)
-      s1 = mapSet(rtrans, s1, Value(), opright);
+      s1 = mapSet(rtrans, s1, Value(), opright, a);
     latSets[sNew].append(latSets[s1]);
   }
   return sNew;
 }
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
-                        Operation *op) {
+                        Operation *op, Attribute a) {
   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
          TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
     const auto &point = latPoints[p];
-    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
+    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
   }
   return sNew;
 }
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
     return "log1p";
+  case TensorExp::Kind::kRelu:
+    return "relu";
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
     return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     {
       const ExprId e0 = expr.children.e0;
       const Value v = expr.val;
-      return mapSet(kind, buildLattices(e0, i), v);
+      Attribute a = expr.attr;
+      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
     }
   case TensorExp::Kind::kBinaryBranch:
   case TensorExp::Kind::kSelect:
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
+/// Only returns true if we are certain this is a zero.
+static bool isCertainZero(Value val) {
+  if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
+    ArrayAttr arrayAttr = c.getValue();
+    return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+           cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
+  }
+  if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
+    return c.value() == 0;
+  if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
+    return c.value().isZero();
+  return false;
+}
+
 /// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(ExprId e) const {
   const auto &expr = exp(e);
   if (expr.kind == TensorExp::Kind::kInvariant) {
+    // Note that this is different from isCertainZero() in a subtle
+    // way by always returning true for non-constants.
     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
       ArrayAttr arrayAttr = c.getValue();
       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
+// Recognizes a direct GT comparison.
+static bool isGreater(TensorExp::Kind kind, Attribute attr) {
+  if (kind == TensorExp::Kind::kCmpI) {
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
+    return pred == arith::CmpIPredicate::ugt ||
+           pred == arith::CmpIPredicate::sgt;
+  }
+  if (kind == TensorExp::Kind::kCmpF) {
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
+    return pred == arith::CmpFPredicate::UGT ||
+           pred == arith::CmpFPredicate::OGT;
+  }
+  return false;
+}
+
 std::pair<std::optional<ExprId>, bool>
 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // or belonging to an enveloping op) is considered invariant.
     return {addInvariantExp(v), /*hasSpDep=*/false};
   }
+
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
+      if (auto selop = dyn_cast<arith::SelectOp>(def)) {
+        // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
+        // operation inside a very specific ternary select operation.
+        // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
+        const auto &cnd = exp(*x);
+        if (isGreater(cnd.kind, cnd.attr) &&
+            exp(*y).kind == TensorExp::Kind::kTensor &&
+            exp(*z).kind == TensorExp::Kind::kInvariant &&
+            isCertainZero(exp(*z).val)) {
+          const auto &a = exp(cnd.children.e0);
+          const auto &b = exp(cnd.children.e1);
+          if (a.kind == TensorExp::Kind::kTensor &&
+              a.tensor == exp(*y).tensor &&
+              b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
+            return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
+                           nullptr, cnd.attr),
+                    yDepSp};
+          }
+        }
+      }
     }
   }
 
@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // tensors).
   if (def->getNumResults() != 1) // only handle single result operation.
     return {std::nullopt, false};
-
   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
   // Builds all the sub-expressions
   for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return {e, false};
     }
   }
+
   // Cannot build.
   return {std::nullopt, false};
 }
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
 }
 
+static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
+                       Attribute attr) {
+  Type tp = v0.getType();
+  auto zero =
+      rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+  Value cmp;
+  if (isa<FloatType>(tp)) {
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
+  } else {
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
+  }
+  return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
+}
+
 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                        Value v1) const {
   const auto &expr = exp(e);
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
     return rewriter.create<math::Log1pOp>(loc, v0);
   case TensorExp::Kind::kLog1pC:
     return rewriter.create<complex::Log1pOp>(loc, v0);
+  case TensorExp::Kind::kRelu:
+    return buildRelu(rewriter, loc, v0, expr.attr);
   case TensorExp::Kind::kSinF:
     return rewriter.create<math::SinOp>(loc, v0);
   case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
   case TensorExp::Kind::kUnary:
     return buildUnaryPresent(rewriter, loc, expr.op, v0);
   case TensorExp::Kind::kSelect:
-    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+    return insertYieldOp(rewriter, loc,
+                         cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
                          {v0});
   case TensorExp::Kind::kBinary:
     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_relu.mlir b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
new file mode 100644
index 0000000000000..25f0c790b43d7
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+    map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
+//
+// Make sure a simple ReLU passes the sparsifier
+//
+// CHECK-LABEL: func.func @relu
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             arith.cmpf ugt
+// CHECK:             arith.select
+//
+func.func @relu(%arg0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = tensor.empty() : tensor<10x20x30xf64>
+  %1 = linalg.generic {
+      indexing_maps = [#map, #map],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<10x20x30xf64, #sparse>)
+      outs(%0 : tensor<10x20x30xf64>) {
+  ^bb0(%in: f64, %out: f64):
+      %2 = arith.cmpf ugt, %in, %cst : f64
+      %3 = arith.select %2, %in, %cst : f64
+      linalg.yield %3 : f64
+  } -> tensor<10x20x30xf64>
+  %cast = tensor.cast %1 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+  return %cast : tensor<10x20x30xf64, #sparse>
+}

aartbik merged commit 70e227a into llvm:main on May 13, 2024 (1 of 4 checks passed).
aartbik deleted the bik branch on May 13, 2024 at 21:02.
Labels: mlir:sparse (Sparse compiler in MLIR), mlir