[mlir][sparse] recognize ReLu operation during sparsification #92016
Conversation
This is a proof-of-concept recognition of the most basic forms of ReLu operations, used to showcase sparsification of end-to-end PyTorch models. In the long run, we must avoid lowering such constructs too early (which creates the need to raise them back later). See the discussion at https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918
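For reference, the recognized form is the compare-and-select lowering of ReLu(x) = max(x, 0). A minimal sketch of the scalar body involved (mirroring the linalg.generic region in the test added by this PR; %in is the block argument carrying the input element):

// ReLu(x) = max(x, 0), lowered as a compare followed by a select.
%cst = arith.constant 0.000000e+00 : f64
%2 = arith.cmpf ugt, %in, %cst : f64  // is the input strictly greater than zero?
%3 = arith.select %2, %in, %cst : f64 // if so, keep it; otherwise yield zero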
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

This is a proof-of-concept recognition of the most basic forms of ReLu operations, used to showcase sparsification of end-to-end PyTorch models. In the long run, we must avoid lowering such constructs too early (which creates the need to raise them back later). See the discussion at https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918

Full diff: https://github.com/llvm/llvm-project/pull/92016.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 7f9820df984b2..b8d278152dc05 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -144,6 +144,7 @@ enum class TensorExp::Kind {
kExpm1C,
kLog1pF,
kLog1pC,
+ kRelu,
kSinF,
kSinC,
kTanhF,
@@ -316,7 +317,7 @@ class Merger {
/// lattice point on an expression E is simply copied over, but with OP E
/// as new expression. Returns the identifier of the new set.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
- Operation *op = nullptr);
+ Operation *op = nullptr, Attribute attr = nullptr);
/// Maps the binary operator to the same operation but with one of its operand
/// set to zero, i.e. each lattice point on an expression E is simply copied
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 308fbd965259d..0258f797143cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
Operation *o, Attribute a)
- : kind(k), val(v), op(o) {
+ : kind(k), val(v), op(o), attr(a) {
switch (kind) {
// Leaf.
case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
case TensorExp::Kind::kCmpF:
case TensorExp::Kind::kCmpI:
assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
- attr = a;
children.e0 = x;
children.e1 = y;
return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
const LatSetId sNew = conjSet(e, s0, s1, op);
TensorExp::Kind kind = exp(e).kind;
-
// Followed by all in s0.
latSets[sNew].append(latSets[s0]);
// Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
bool includeLeft, TensorExp::Kind ltrans,
Operation *opleft, bool includeRight,
TensorExp::Kind rtrans, Operation *opright) {
+ Attribute a = exp(e).attr;
const LatSetId sNew = conjSet(e, s0, s1, orig);
// Left Region.
if (includeLeft) {
if (opleft)
- s0 = mapSet(ltrans, s0, Value(), opleft);
+ s0 = mapSet(ltrans, s0, Value(), opleft, a);
latSets[sNew].append(latSets[s0]);
}
// Right Region.
if (includeRight) {
if (opright)
- s1 = mapSet(rtrans, s1, Value(), opright);
+ s1 = mapSet(rtrans, s1, Value(), opright, a);
latSets[sNew].append(latSets[s1]);
}
return sNew;
}
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
- Operation *op) {
+ Operation *op, Attribute a) {
assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
TensorExp::Kind::kDenseOp == kind);
const LatSetId sNew = addSet();
auto &setNew = latSets[sNew];
for (const LatPointId p : set(s0)) {
const auto &point = latPoints[p];
- setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
+ setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
}
return sNew;
}
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
return "log1p";
+ case TensorExp::Kind::kRelu:
+ return "relu";
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
{
const ExprId e0 = expr.children.e0;
const Value v = expr.val;
- return mapSet(kind, buildLattices(e0, i), v);
+ Attribute a = expr.attr;
+ return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
}
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kSelect:
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
return buildTensorExp(op, yield->getOperand(0)).first;
}
+/// Only returns true if we are certain this is a zero.
+static bool isCertainZero(Value val) {
+ if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
+ ArrayAttr arrayAttr = c.getValue();
+ return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
+ }
+ if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
+ return c.value() == 0;
+ if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
+ return c.value().isZero();
+ return false;
+}
+
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
const auto &expr = exp(e);
if (expr.kind == TensorExp::Kind::kInvariant) {
+ // Note that this is different from isCertainZero() in a subtle
+ // way by always returning true for non-constants.
if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
}
+// Recognizes a direct GT comparison.
+static bool isGreater(TensorExp::Kind kind, Attribute attr) {
+ if (kind == TensorExp::Kind::kCmpI) {
+ auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
+ return pred == arith::CmpIPredicate::ugt ||
+ pred == arith::CmpIPredicate::sgt;
+ }
+ if (kind == TensorExp::Kind::kCmpF) {
+ auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
+ return pred == arith::CmpFPredicate::UGT ||
+ pred == arith::CmpFPredicate::OGT;
+ }
+ return false;
+}
+
std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// or belonging to an enveloping op) is considered invariant.
return {addInvariantExp(v), /*hasSpDep=*/false};
}
+
// Something defined outside is invariant.
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
}
+
// Construct binary operations if subexpressions can be built.
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
}
+
// Construct ternary operations if subexpressions can be built.
if (def->getNumOperands() == 3) {
const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (isAdmissibleBranch(redop, redop.getRegion()))
return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
}
+ if (auto selop = dyn_cast<arith::SelectOp>(def)) {
+ // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
+ // operation inside a very specific ternary select operation.
+ // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
+ const auto &cnd = exp(*x);
+ if (isGreater(cnd.kind, cnd.attr) &&
+ exp(*y).kind == TensorExp::Kind::kTensor &&
+ exp(*z).kind == TensorExp::Kind::kInvariant &&
+ isCertainZero(exp(*z).val)) {
+ const auto &a = exp(cnd.children.e0);
+ const auto &b = exp(cnd.children.e1);
+ if (a.kind == TensorExp::Kind::kTensor &&
+ a.tensor == exp(*y).tensor &&
+ b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
+ return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
+ nullptr, cnd.attr),
+ yDepSp};
+ }
+ }
+ }
}
}
@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// tensors).
if (def->getNumResults() != 1) // only handle single result operation.
return {std::nullopt, false};
-
SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
// Builds all the sub-expressions
for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return {e, false};
}
}
+
// Cannot build.
return {std::nullopt, false};
}
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
+static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
+ Attribute attr) {
+ Type tp = v0.getType();
+ auto zero =
+ rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+ Value cmp;
+ if (isa<FloatType>(tp)) {
+ auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
+ cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
+ } else {
+ auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
+ cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
+ }
+ return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
+}
+
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
Value v1) const {
const auto &expr = exp(e);
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
return rewriter.create<math::Log1pOp>(loc, v0);
case TensorExp::Kind::kLog1pC:
return rewriter.create<complex::Log1pOp>(loc, v0);
+ case TensorExp::Kind::kRelu:
+ return buildRelu(rewriter, loc, v0, expr.attr);
case TensorExp::Kind::kSinF:
return rewriter.create<math::SinOp>(loc, v0);
case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kUnary:
return buildUnaryPresent(rewriter, loc, expr.op, v0);
case TensorExp::Kind::kSelect:
- return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+ return insertYieldOp(rewriter, loc,
+ cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
{v0});
case TensorExp::Kind::kBinary:
return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_relu.mlir b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
new file mode 100644
index 0000000000000..25f0c790b43d7
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
+//
+// Make sure a simple ReLU passes the sparsifier
+//
+// CHECK-LABEL: func.func @relu
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: arith.cmpf ugt
+// CHECK: arith.select
+//
+func.func @relu(%arg0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %0 = tensor.empty() : tensor<10x20x30xf64>
+ %1 = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<10x20x30xf64, #sparse>)
+ outs(%0 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %2 = arith.cmpf ugt, %in, %cst : f64
+ %3 = arith.select %2, %in, %cst : f64
+ linalg.yield %3 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %1 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+ return %cast : tensor<10x20x30xf64, #sparse>
+}
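The CHECK lines above expect a scf.for loop nest with the compare-and-select applied per stored element. As a rough illustration of that shape, here is a hand-written, self-contained loop over a flat buffer (a sketch only; the buffer layout, bounds, and names are placeholders, not the IR the sparsifier actually emits):

func.func @relu_innermost(%buf: memref<?xf64>, %n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %zero = arith.constant 0.0 : f64
  // Walk the stored values and clamp negatives to zero in place.
  scf.for %k = %c0 to %n step %c1 {
    %v = memref.load %buf[%k] : memref<?xf64>
    %cmp = arith.cmpf ugt, %v, %zero : f64
    %sel = arith.select %cmp, %v, %zero : f64
    memref.store %sel, %buf[%k] : memref<?xf64>
  }
  return
}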