Skip to content

Commit

Permalink
variadic identity (llvm#40)
Browse files Browse the repository at this point in the history
* added identity transformations for variadic operations and, or, xor, add, mul

* added variadic verifier

* ()[size - 1] --> ().back

* test the constant on the left case

* moved verifier TD --> Ops.cpp

* added zero operand variadic negative test

* fixed typo

* removed unused value1
  • Loading branch information
drom authored Jul 16, 2020
1 parent 9d2042f commit 2045dc3
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 53 deletions.
3 changes: 3 additions & 0 deletions include/circt/Dialect/RTL/Combinatorial.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ class UTVariadicRTLOp<string mnemonic, list<OpTrait> traits = []> :
!listconcat(traits,
[SameTypeOperands, SameOperandsAndResultType])> {

let hasCanonicalizer = 1;
let hasFolder = 1;
let verifier = [{ return ::verifyUTVariadicRTLOp(*this); }];

let assemblyFormat = [{
$inputs attr-dict `:` type($result)
}];
Expand Down
215 changes: 173 additions & 42 deletions lib/Dialect/RTL/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace circt;
Expand Down Expand Up @@ -157,105 +158,235 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
// Variadic operations
//===----------------------------------------------------------------------===//

static LogicalResult verifyUTVariadicRTLOp(Operation *op) {
auto size = op->getOperands().size();
if (size < 1)
return op->emitOpError("requires 1 or more args");

return success();
}

OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
auto size = inputs().size();

// and(x) -> x -- noop
if (size == 1u)
return inputs()[0];

APInt value;

// and(..., 0) -> 0 -- annulment
if (matchPattern(inputs()[size - 1], m_RConstant(value)) &&
value.isNullValue())
return inputs()[size - 1];

/// TODO: and(..., '1) -> and(...) -- identity
/// TODO: and(..., x, x) -> and(..., x) -- idempotent
/// TODO: and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
/// folding
/// TODO: and(x, and(...)) -> and(x, ...) -- flatten
/// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
if (matchPattern(inputs().back(), m_RConstant(value)) && value.isNullValue())
return inputs().back();

return {};
}

void AndOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
struct Folder final : public OpRewritePattern<AndOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AndOp op,
PatternRewriter &rewriter) const override {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

APInt value;

// and(..., '1) -> and(...) -- identity
if (matchPattern(inputs.back(), m_RConstant(value)) &&
value.isAllOnesValue()) {

rewriter.replaceOpWithNewOp<AndOp>(op, op.getType(),
inputs.drop_back());
return success();
}

/// TODO: and(..., c1, c2) -> and(..., c3) -- constant folding
/// TODO: and(x, and(...)) -> and(x, ...) -- flatten
/// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
/// TODO: and(..., x, x) -> and(..., x) -- idempotent
return failure();
}
};
results.insert<Folder>(context);
}

OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
auto size = inputs().size();

// or(x) -> x -- noop
if (size == 1u)
return inputs()[0];

APInt value;

// or(..., '1) -> '1 -- annulment
if (matchPattern(inputs()[size - 1], m_RConstant(value)) &&
if (matchPattern(inputs().back(), m_RConstant(value)) &&
value.isAllOnesValue())
return inputs()[size - 1];

/// TODO: or(..., 0) -> or(...) -- identity
/// TODO: or(..., x, x) -> or(..., x) -- idempotent
/// TODO: or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant
/// folding
/// TODO: or(x, or(...)) -> or(x, ...) -- flatten
/// TODO: or(..., x, not(x)) -> or(..., '1) -- complement

return inputs().back();
return {};
}

void OrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
struct Folder final : public OpRewritePattern<OrOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OrOp op,
PatternRewriter &rewriter) const override {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

APInt value;

// or(..., 0) -> or(...) -- identity
if (matchPattern(inputs.back(), m_RConstant(value)) &&
value.isNullValue()) {

rewriter.replaceOpWithNewOp<OrOp>(op, op.getType(), inputs.drop_back());
return success();
}
/// TODO: or(..., x, x) -> or(..., x) -- idempotent
/// TODO: or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant
/// folding
/// TODO: or(x, or(...)) -> or(x, ...) -- flatten
/// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
return failure();
}
};
results.insert<Folder>(context);
}

OpFoldResult XorOp::fold(ArrayRef<Attribute> operands) {
auto size = inputs().size();

// xor(x) -> x -- noop
if (size == 1u)
return inputs()[0];

/// TODO: xor(..., 0) -> xor(...) -- identity
/// TODO: xor(..., '1) -> not(xor(...))
/// TODO: xor(..., x, x) -> xor(..., 0) -- idempotent?
/// TODO: xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2 -- constant
/// folding
/// TODO: xor(x, xor(...)) -> xor(x, ...) -- flatten
/// TODO: xor(..., x, not(x)) -> xor(..., '1)

return {};
}

void XorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
struct Folder final : public OpRewritePattern<XorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(XorOp op,
PatternRewriter &rewriter) const override {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

APInt value;

// xor(..., 0) -> xor(...) -- identity
if (matchPattern(inputs.back(), m_RConstant(value)) &&
value.isNullValue()) {

rewriter.replaceOpWithNewOp<XorOp>(op, op.getType(),
inputs.drop_back());
return success();
}

/// TODO: xor(..., '1) -> not(xor(...))
/// TODO: xor(..., x, x) -> xor(..., 0) -- idempotent?
/// TODO: xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2 --
/// constant folding
/// TODO: xor(x, xor(...)) -> xor(x, ...) -- flatten
/// TODO: xor(..., x, not(x)) -> xor(..., '1)
return failure();
}
};
results.insert<Folder>(context);
}

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
auto size = inputs().size();

// add(x) -> x -- noop
if (size == 1u)
return inputs()[0];

/// TODO: add(..., 0) -> add(...) -- identity
/// TODO: add(..., x, x) -> add(..., shl(x, 1))
/// TODO: add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant
/// folding
/// TODO: add(x, add(...)) -> add(x, ...) -- flatten

return {};
}

void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
struct Folder final : public OpRewritePattern<AddOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AddOp op,
PatternRewriter &rewriter) const override {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

APInt value;

// add(..., 0) -> add(...) -- identity
if (matchPattern(inputs.back(), m_RConstant(value)) &&
value.isNullValue()) {
rewriter.replaceOpWithNewOp<AddOp>(op, op.getType(),
inputs.drop_back());
return success();
}

/// TODO: add(..., x, x) -> add(..., shl(x, 1))
/// TODO: add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 --
/// constant folding
/// TODO: add(x, add(...)) -> add(x, ...) -- flatten
return failure();
}
};
results.insert<Folder>(context);
}

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
auto size = inputs().size();

// mul(x) -> x -- noop
if (size == 1u)
return inputs()[0];

APInt value;

// mul(..., 0) -> 0 -- annulment
if (matchPattern(inputs()[size-1], m_RConstant(value)) &&
value.isNullValue())
return inputs()[size-1];

/// TODO: mul(..., 1) -> mul(...) -- identity
/// TODO: mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant
/// folding
/// TODO: mul(a, mul(...)) -> mul(a, ...) -- flatten
if (matchPattern(inputs().back(), m_RConstant(value)) && value.isNullValue())
return inputs().back();

return {};
}

void MulOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
struct Folder final : public OpRewritePattern<MulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MulOp op,
PatternRewriter &rewriter) const override {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

APInt value;

// mul(..., 1) -> mul(...) -- identity
if (matchPattern(inputs.back(), m_RConstant(value)) && (value == 1u)) {
rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
inputs.drop_back());
return success();
}

/// TODO: mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 --
/// constant folding
/// TODO: mul(a, mul(...)) -> mul(a, ...) -- flatten

return failure();
}
};
results.insert<Folder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 2045dc3

Please sign in to comment.