Skip to content

Commit

Permalink
Split op folding logic out to its own file, NFC.
Browse files Browse the repository at this point in the history
  • Loading branch information
lattner committed Apr 30, 2020
1 parent 109bb24 commit 14681d5
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 218 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ $ git clone git@github.com:llvm/llvm-project.git
$ git clone git@github.com:sifive/clattner-experimental.git cirt
```

3) HACK: Add symlink because I can't figure out how to get `LLVM_EXTERNAL_CIRT_SOURCE_DIR` to work with cmake:
3) HACK: Add symlink because I can't figure out how to get
`LLVM_EXTERNAL_CIRT_SOURCE_DIR` to work with cmake (I'd love help with
this!):

```
$ cd ~/Projects/llvm-project
$ ln -s ../cirt cirt
```

4) Configure the build to build MLIR and CIRT (MLIR is probably not necessary, but it builds
reasonably fast and is good to provide a sanity check that things are working):
4) Configure the build to build MLIR and CIRT using a command like this
(replace `/Users/chrisl` with the paths you want to use):

```
$ cd ~/Projects/llvm-project
Expand All @@ -51,7 +53,8 @@ To get something that runs fast, use `-DCMAKE_BUILD_TYPE=Release` or
you want debug info to go with it. `RELEASE` mode makes a very large difference
in performance.

5) Build MLIR and run MLIR tests as a smoketest:
5) Build MLIR and run MLIR tests as a smoketest - this isn't needed, but is
reasonably fast and a good sanity check:

```
$ ninja check-mlir
Expand Down
222 changes: 222 additions & 0 deletions lib/Dialect/FIRRTL/OpFolds.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
//===- OpFolds.cpp - Implement folds and canonicalizations for ops --------===//
//
//===----------------------------------------------------------------------===//

#include "cirt/Dialect/FIRRTL/Ops.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/Matchers.h"

using namespace cirt;
using namespace firrtl;

//===----------------------------------------------------------------------===//
// Fold Hooks
//===----------------------------------------------------------------------===//

struct ConstantIntMatcher {
APInt &value;
ConstantIntMatcher(APInt &value) : value(value) {}
bool match(Operation *op) {
if (auto cst = dyn_cast<ConstantOp>(op)) {
value = cst.value();
return true;
}
return false;
}
};

static inline ConstantIntMatcher m_FConstant(APInt &value) {
return ConstantIntMatcher(value);
}

// TODO: Move to DRR.
OpFoldResult AndPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

/// and(x, 0) -> 0
if (matchPattern(rhs(), m_FConstant(value)) && value.isNullValue() &&
rhs().getType() == getType())
return rhs();

/// and(x, -1) -> x
if (matchPattern(rhs(), m_FConstant(value)) && value.isAllOnesValue() &&
lhs().getType() == getType())
return lhs();

/// and(x, x) -> x
if (lhs() == rhs() && rhs().getType() == getType())
return rhs();

return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a & b; });
}

OpFoldResult OrPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

/// or(x, 0) -> x
if (matchPattern(rhs(), m_FConstant(value)) && value.isNullValue() &&
lhs().getType() == getType())
return lhs();

/// or(x, -1) -> -1
if (matchPattern(rhs(), m_FConstant(value)) && value.isAllOnesValue() &&
rhs().getType() == getType())
return rhs();

/// or(x, x) -> x
if (lhs() == rhs())
return rhs();

return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a | b; });
}

OpFoldResult XorPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

/// xor(x, 0) -> x
if (matchPattern(rhs(), m_FConstant(value)) && value.isNullValue() &&
lhs().getType() == getType())
return lhs();

/// xor(x, x) -> 0
if (lhs() == rhs()) {
auto width = getType().cast<IntType>().getWidthOrSentinel();
if (width == -1)
width = 1;
auto type = IntegerType::get(width, getContext());
return Builder(getContext()).getZeroAttr(type);
}

return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a ^ b; });
}

OpFoldResult EQPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

if (matchPattern(rhs(), m_FConstant(value))) {
APInt lhsCst;
// Constant fold.
if (matchPattern(lhs(), m_FConstant(lhsCst)) &&
value.getBitWidth() == lhsCst.getBitWidth()) {
auto result = value == lhsCst;
return IntegerAttr::get(IntegerType::get(1, getContext()),
APInt(1, result));
}

/// eq(x, 1) -> x when x is 1 bit.
/// TODO: Support SInt<1> on the LHS etc.
if (value.isAllOnesValue() && lhs().getType() == getType())
return lhs();

/// TODO: eq(x, 0) -> not(x) when x is 1 bit.
/// TODO: eq(x, 0) -> not(orr(x)) when x is >1 bit
/// TODO: eq(x, ~0) -> andr(x)) when x is >1 bit
}

return {};
}

OpFoldResult NEQPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

if (matchPattern(rhs(), m_FConstant(value))) {
APInt lhsCst;
// Constant fold.
if (matchPattern(lhs(), m_FConstant(lhsCst)) &&
value.getBitWidth() == lhsCst.getBitWidth()) {
auto result = value != lhsCst;
return IntegerAttr::get(IntegerType::get(1, getContext()),
APInt(1, result));
}

/// neq(x, 0) -> x when x is 1 bit.
/// TODO: Support SInt<1> on the LHS etc.
if (value.isNullValue() && lhs().getType() == getType())
return lhs();

/// TODO: neq(x, 0) -> not(orr(x)) when x is >1 bit
/// TODO: neq(x, 1) -> not(x) when x is 1 bit.
/// TODO: neq(x, ~0) -> andr(x)) when x is >1 bit
}

return {};
}

OpFoldResult BitsPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

// If we are extracting the entire input, then return it.
if (input().getType() == getType() &&
getType().cast<IntType>().getWidthOrSentinel() != -1)
return input();

return {};
}

OpFoldResult MuxPrimOp::fold(ArrayRef<Attribute> operands) {
APInt value;

/// mux(0/1, x, y) -> x or y
if (matchPattern(sel(), m_FConstant(value))) {
if (value.isNullValue() && low().getType() == getType())
return low();
if (!value.isNullValue() && high().getType() == getType())
return high();
}

// mux(cond, x, x) -> x
if (high() == low())
return high();

// mux(cond, x, cst)
if (matchPattern(low(), m_FConstant(value))) {
APInt c1;
// mux(cond, c1, c2)
if (matchPattern(high(), m_FConstant(c1))) {
// mux(cond, 1, 0) -> cond
if (c1.isOneValue() && value.isNullValue() &&
getType() == sel().getType())
return sel();

// TODO: x ? ~0 : 0 -> sext(x)
// TODO: "x ? c1 : c2" -> many tricks
}
// TODO: "x ? a : 0" -> sext(x) & a
}

// TODO: "x ? c1 : y" -> "~x ? y : c1"

return {};
}

OpFoldResult PadPrimOp::fold(ArrayRef<Attribute> operands) {
auto input = this->input();
auto inputType = input.getType().cast<IntType>();

// pad(x) -> x if the width doesn't change.
if (input.getType() == getType())
return input;

// Need to know the input width.
int32_t width = inputType.getWidthOrSentinel();
if (width == -1)
return {};

APInt value;

/// pad(cst1) -> cst2
if (matchPattern(input, m_FConstant(value))) {
auto destWidth = getType().cast<IntType>().getWidthOrSentinel();
if (inputType.isSigned())
value = value.sext(destWidth);
else
value = value.zext(destWidth);

return IntegerAttr::get(IntegerType::get(destWidth, getContext()), value);
}

return {};
}
Loading

0 comments on commit 14681d5

Please sign in to comment.