Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Handshake] Adding func instance op for integration #7812

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/circt/Dialect/Handshake/HandshakeOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "circt/Dialect/Handshake/HandshakeDialect.h"
#include "circt/Dialect/Handshake/HandshakeInterfaces.h"
#include "circt/Dialect/Seq/SeqTypes.h"
#include "circt/Support/LLVM.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down
51 changes: 51 additions & 0 deletions include/circt/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

include "circt/Dialect/ESI/ESITypes.td"
include "circt/Dialect/Seq/SeqTypes.td"

// @mortbopet: some kind of support for interfaces as parent ops is currently
// being tracked here: https://github.com/llvm/llvm-project/pull/66196
class HasParentInterface<string interface>
Expand Down Expand Up @@ -138,6 +141,54 @@ def FuncOp : Op<Handshake_Dialect, "func", [
let hasCustomAssemblyFormat = 1;
}

def ESIInstanceOp : Op<Handshake_Dialect, "esi_instance", [
CallOpInterface,
HasClock,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Instantiate a Handshake circuit";
let description = [{
Instantiate (call) a Handshake function in a non-Handshake design using ESI
channels as the outside connections.
}];
let arguments = (ins FlatSymbolRefAttr:$module, StrAttr:$instName,
ClockType:$clk, I1:$rst,
Variadic<ChannelType>:$opOperands);
let results = (outs Variadic<ChannelType>);

let assemblyFormat = [{
$module $instName `clk` $clk `rst` $rst
`(` $opOperands `)` attr-dict `:` functional-type($opOperands, results)
}];

let extraClassDeclaration = [{
// Account for `clk` and `rst` operands vs call arguments.
static constexpr int NumFixedOperands = 2;

/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}

operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

/// Return the module of this operation.
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("module");
}

/// Set the callee for this operation.
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
(*this)->setAttr(getModuleAttrName(), callee.get<mlir::SymbolRefAttr>());
}

MutableOperandRange getArgOperandsMutable() {
return getOpOperandsMutable();
}
}];
}

// InstanceOp
def InstanceOp : Handshake_Op<"instance", [
CallOpInterface,
Expand Down
47 changes: 47 additions & 0 deletions lib/Conversion/HandshakeToHW/HandshakeToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/ESI/ESIOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWSymCache.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
Expand Down Expand Up @@ -1054,6 +1055,40 @@ class InstanceConversionPattern
}
};

class ESIInstanceConversionPattern
: public OpConversionPattern<handshake::ESIInstanceOp> {
public:
ESIInstanceConversionPattern(MLIRContext *context,
const HWSymbolCache &symCache)
: OpConversionPattern(context), symCache(symCache) {}

LogicalResult
matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The operand signature of this op is very similar to the lowered
// `handshake.func`s (especially since handshake uses ESI channels
// internally). Whereas ESIInstance ops have 'clk' and 'rst' at the
// beginning, lowered `handshake.func`s have them at the end. So we've just
// got to re-arrange them.
SmallVector<Value> operands;
for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
i < e; ++i)
operands.push_back(adaptor.getOperands()[i]);
operands.push_back(adaptor.getClk());
operands.push_back(adaptor.getRst());
// Locate the lowered module so the instance builder can get all the
// metadata.
Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
// And replace the op with an instance of the target module.
rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
op.getInstNameAttr(), operands);
return success();
}

private:
const HWSymbolCache &symCache;
};

class ReturnConversionPattern
: public OpConversionPattern<handshake::ReturnOp> {
public:
Expand Down Expand Up @@ -1976,6 +2011,18 @@ class HandshakeToHWPass
for (auto hwModule : mod.getOps<hw::HWModuleOp>())
if (failed(convertExtMemoryOps(hwModule)))
return signalPassFailure();

// Run conversions which need see everything.
HWSymbolCache symbolCache;
symbolCache.addDefinitions(mod);
symbolCache.freeze();
RewritePatternSet patterns(mod.getContext());
patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
symbolCache);
if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
mod->emitOpError() << "error during conversion";
signalPassFailure();
}
}
};
} // end anonymous namespace
Expand Down
51 changes: 51 additions & 0 deletions lib/Dialect/Handshake/HandshakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/ESI/ESITypes.h"
#include "circt/Support/LLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -1335,6 +1336,56 @@ void JoinOp::print(OpAsmPrinter &p) {
p << " : " << getData().getTypes();
}

LogicalResult
ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the module attribute was specified.
auto fnAttr = this->getModuleAttr();
assert(fnAttr && "requires a 'module' symbol reference attribute");

FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid handshake function";

// Verify that the operand and result types match the callee.
auto fnType = fn.getFunctionType();
if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
return emitOpError(
"incorrect number of operands for the referenced handshake function");

for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
Type operandType = getOperand(i + NumFixedOperands).getType();
auto channelType = dyn_cast<esi::ChannelType>(operandType);
if (!channelType)
return emitOpError("operand type mismatch: expected channel type, but "
"provided ")
<< operandType << " for operand number " << i;
if (channelType.getInner() != fnType.getInput(i))
return emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< getOperand(i).getType() << " for operand number " << i;
}

if (fnType.getNumResults() != getNumResults())
return emitOpError(
"incorrect number of results for the referenced handshake function");

for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
Type resultType = getResult(i).getType();
auto channelType = dyn_cast<esi::ChannelType>(resultType);
if (!channelType)
return emitOpError("result type mismatch: expected channel type, but "
"provided ")
<< resultType << " for result number " << i;
if (channelType.getInner() != fnType.getResult(i))
return emitOpError("result type mismatch: expected result type ")
<< fnType.getResult(i) << ", but provided "
<< getResult(i).getType() << " for result number " << i;
}

return success();
}

/// Based on mlir::func::CallOp::verifySymbolUses
LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the module attribute was specified.
Expand Down
11 changes: 11 additions & 0 deletions test/Conversion/HandshakeToHW/test_instance.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,14 @@ handshake.func @bar(%in : i32) -> (i32) {
%out = handshake.instance @foo(%in) : (i32) -> (i32)
handshake.return %out : i32
}

// -----

handshake.func @foo(%ctrl : i32) -> i32 {
return %ctrl : i32
}

hw.module @outer(in %clk: !seq.clock, in %rst: i1, in %ctrl: !esi.channel<i32>, out out: !esi.channel<i32>) {
%ret = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst (%ctrl) : (!esi.channel<i32>) -> (!esi.channel<i32>)
hw.output %ret : !esi.channel<i32>
}
21 changes: 21 additions & 0 deletions test/Dialect/Handshake/call.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,24 @@ handshake.func @invalid_instance_op(%arg0 : i32) -> i32 {
instance @foo(%arg0) : (i32) -> (i32)
return %arg0 : i32
}

// -----

// CHECK-LABEL: handshake.func @foo(
// CHECK-SAME: %[[VAL_0:.*]]: i32, ...) -> i32
// CHECK: return %[[VAL_0]] : i32
// CHECK: }

// CHECK-LABEL: hw.module @outer(in %clk : !seq.clock, in %rst : i1, in %ctrl : !esi.channel<i32>, out out : !esi.channel<i32>) {
// CHECK-NEXT: [[R0:%.+]] = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst(%ctrl) : (!esi.channel<i32>) -> !esi.channel<i32>
// CHECK-NEXT: hw.output [[R0]] : !esi.channel<i32>


handshake.func @foo(%ctrl : i32) -> i32 {
return %ctrl : i32
}

hw.module @outer(in %clk: !seq.clock, in %rst: i1, in %ctrl: !esi.channel<i32>, out out: !esi.channel<i32>) {
%ret = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst (%ctrl) : (!esi.channel<i32>) -> (!esi.channel<i32>)
hw.output %ret : !esi.channel<i32>
}
Loading