Skip to content

Commit

Permalink
[Handshake] Persist handshake.FuncOp argument names (llvm#2048)
Browse files Browse the repository at this point in the history
This commit modifies the parsing of `handshake.func` operations to store the SSA argument names as a new attribute `argNames`. This attribute is then later used in `handshakeToFIRRTL` to guide argument naming.

This, unfortunately, does not solve how to persist names from `standard`->`handshake` - a `builtin.func` does not persist parsed SSA names. For reference, these are discarded at this point: https://github.com/llvm/llvm-project/blob/main/mlir/lib/IR/BuiltinDialect.cpp#L116-L119

If provided, an `argNames` attribute set for a `builtin.func` will be copied to the `handshake.func` during `StandardToHandshake`.
  • Loading branch information
mortbopet committed Nov 2, 2021
1 parent f8b1161 commit 6cb07bf
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 43 deletions.
24 changes: 4 additions & 20 deletions include/circt/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
// extra verification conditions. In particular, each Value must
// only have a single use. Also, it defines a Dominance-Free Scope
def FuncOp : Op<Handshake_Dialect, "func", [
NativeOpTrait<"IsIsolatedFromAbove">,
NativeOpTrait<"FunctionLike">,
IsolatedFromAbove,
FunctionLike,
Symbol,
RegionKindInterface
]> {
Expand All @@ -35,22 +35,13 @@ def FuncOp : Op<Handshake_Dialect, "func", [

let builders =
[OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
CArg<"ArrayRef<NamedAttrList>", "{}">:$argAttrs),
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
$_state.addAttribute(SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();

// Not needed?? Arguments are already included in attrs
/* assert(type.getNumInputs() == argAttrs.size());
SmallString<8> argAttrName;
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (auto argDict = argAttrs[i].getDictionary())
$_state.addAttribute(getArgAttrName(i, argAttrName),
argDict);*/
}]>];

let extraClassDeclaration = [{
Expand All @@ -63,9 +54,6 @@ def FuncOp : Op<Handshake_Dialect, "func", [
.getValue()
.cast<FunctionType>();
}
// bool isVarArg() {
// return getType().getUnderlyingType()->isFunctionVarArg();
// }

// This trait needs access to the hooks defined below.
friend class OpTrait::FunctionLike<handshake::FuncOp>;
Expand Down Expand Up @@ -94,11 +82,7 @@ def FuncOp : Op<Handshake_Dialect, "func", [
}];

let verifier = [{ return ::verify$cppClass(*this); }];
let printer = [{
FunctionType fnType = getType();
mlir::function_like_impl::printFunctionLikeOp(p, *this, fnType.getInputs(),
/*isVariadic=*/true, fnType.getResults());
}];
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}

Expand Down
10 changes: 9 additions & 1 deletion lib/Conversion/HandshakeToFIRRTL/HandshakeToFIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,19 @@ static void extractValues(ArrayRef<ValueVector *> valueVectors, size_t index,
static FModuleOp createTopModuleOp(handshake::FuncOp funcOp, unsigned numClocks,
ConversionPatternRewriter &rewriter) {
llvm::SmallVector<PortInfo, 8> ports;
auto argNames = funcOp->getAttrOfType<ArrayAttr>("argNames");
auto getArgumentName = [&](unsigned index) {
if (argNames && argNames.size() > index)
return rewriter.getStringAttr(
argNames[index].cast<StringAttr>().getValue());
else
return rewriter.getStringAttr("arg" + std::to_string(index));
};

// Add all inputs of funcOp.
unsigned argIndex = 0;
for (auto &arg : funcOp.getArguments()) {
auto portName = rewriter.getStringAttr("arg" + std::to_string(argIndex));
auto portName = getArgumentName(argIndex);
auto bundlePortType = getBundleType(arg.getType());

if (!bundlePortType)
Expand Down
80 changes: 63 additions & 17 deletions lib/Dialect/Handshake/HandshakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,71 @@ static ParseResult verifyFuncOp(handshake::FuncOp op) {
return success();
}

/// Parses a FuncOp signature using
/// mlir::function_like_impl::parseFunctionSignature while getting access to the
/// parsed SSA names to store as attributes.
static ParseResult parseFuncOpArgs(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &entryArgs,
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<Attribute> &argNames,
SmallVectorImpl<NamedAttrList> &argAttrs, SmallVectorImpl<Type> &resTypes,
SmallVectorImpl<NamedAttrList> &resAttrs) {
auto *context = parser.getContext();

bool isVariadic;
if (mlir::function_like_impl::parseFunctionSignature(
parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
isVariadic, resTypes, resAttrs)
.failed())
return failure();

llvm::transform(entryArgs, std::back_inserter(argNames), [&](auto arg) {
return StringAttr::get(context, arg.name.drop_front());
});

return success();
}

static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
mlir::function_like_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
auto &builder = parser.getBuilder();
StringAttr nameAttr;
SmallVector<OpAsmParser::OperandType, 4> args;
SmallVector<Type, 4> argTypes, resTypes;
SmallVector<NamedAttrList, 4> argAttributes, resAttributes;
SmallVector<Attribute> argNames;

// Parse signature
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes) ||
parseFuncOpArgs(parser, args, argTypes, argNames, argAttributes, resTypes,
resAttributes))
return failure();
mlir::function_like_impl::addArgAndResultAttrs(builder, result, argAttributes,
resAttributes);

// Set function type
result.addAttribute(
handshake::FuncOp::getTypeAttrName(),
TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));

// Parse attributes
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();

// If argNames wasn't provided manually, infer argNames attribute from the
// parsed SSA names.
if (!result.attributes.get("argNames"))
result.addAttribute("argNames", builder.getArrayAttr(argNames));

return mlir::function_like_impl::parseFunctionLikeOp(parser, result,
/*allowVariadic=*/true,
buildFuncType);
// Parse region
auto *body = result.addRegion();
return parser.parseRegion(*body, args, argTypes);
}

static void printFuncOp(OpAsmPrinter &p, handshake::FuncOp op) {
FunctionType fnType = op.getType();
mlir::function_like_impl::printFunctionLikeOp(p, op, fnType.getInputs(),
/*isVariadic=*/true,
fnType.getResults());
}

namespace {
Expand Down Expand Up @@ -423,7 +479,6 @@ bool handshake::ControlMergeOp::tryExecute(

void handshake::BranchOp::build(OpBuilder &builder, OperationState &result,
Value dataOperand) {

auto type = dataOperand.getType();
result.types.push_back(type);
result.addOperands(dataOperand);
Expand Down Expand Up @@ -461,7 +516,6 @@ void handshake::ConditionalBranchOp::build(OpBuilder &builder,
OperationState &result,
Value condOperand,
Value dataOperand) {

auto type = dataOperand.getType();
result.types.append(2, type);
result.addOperands(condOperand);
Expand Down Expand Up @@ -524,7 +578,6 @@ bool handshake::StartOp::tryExecute(
}

void EndOp::build(OpBuilder &builder, OperationState &result, Value operand) {

result.addOperands(operand);
}

Expand All @@ -539,12 +592,10 @@ bool handshake::EndOp::tryExecute(

void handshake::ReturnOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<Value> operands) {

result.addOperands(operands);
}

void SinkOp::build(OpBuilder &builder, OperationState &result, Value operand) {

result.addOperands(operand);
}

Expand All @@ -560,7 +611,6 @@ bool handshake::SinkOp::tryExecute(

void handshake::ConstantOp::build(OpBuilder &builder, OperationState &result,
Attribute value, Value operand) {

result.addOperands(operand);

auto type = value.getType();
Expand Down Expand Up @@ -600,7 +650,6 @@ void handshake::TerminatorOp::build(OpBuilder &builder, OperationState &result,
void MemoryOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<Value> operands, int outputs, int control_outputs,
bool lsq, int id, Value memref) {

result.addOperands(operands);

auto memrefType = memref.getType().cast<MemRefType>();
Expand Down Expand Up @@ -746,7 +795,6 @@ bool handshake::MemoryOp::tryExecute(

void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
Value memref, ArrayRef<Value> indices) {

// Address indices
// result.addOperands(memref);
result.addOperands(indices);
Expand Down Expand Up @@ -806,7 +854,6 @@ bool handshake::LoadOp::tryExecute(

void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, ArrayRef<Value> indices) {

// Data
result.addOperands(valueToStore);

Expand Down Expand Up @@ -838,7 +885,6 @@ bool handshake::StoreOp::tryExecute(

void JoinOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<Value> operands) {

auto type = builder.getNoneType();
result.types.push_back(type);

Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/HandshakeToFIRRTL/test_instance.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: circt-opt -lower-handshake-to-firrtl %s | FileCheck %s

// CHECK: firrtl.module @foo(in %[[VAL_0:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_1:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_2:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_3:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_4:.*]]: !firrtl.clock, in %[[VAL_5:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = firrtl.instance "" @bar(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = firrtl.instance "" @bar(in a: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in ctrl: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_12:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_6]], %[[VAL_0]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_7]], %[[VAL_1]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
Expand All @@ -15,7 +15,7 @@
// CHECK: firrtl.module @handshake_sink_1ins_0outs_ctrl(in %[[VAL_13:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>) {

// CHECK: firrtl.module @bar(in %[[VAL_16:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in %[[VAL_17:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %[[VAL_18:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out %[[VAL_19:.*]]: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %[[VAL_20:.*]]: !firrtl.clock, in %[[VAL_21:.*]]: !firrtl.uint<1>) {
// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = firrtl.instance "" @baz(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = firrtl.instance "" @baz(in a: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, in ctrl: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, out arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in clock: !firrtl.clock, in reset: !firrtl.uint<1>)
// CHECK: %[[VAL_28:.*]] = firrtl.instance "" @handshake_sink_1ins_0outs_ctrl(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>)
// CHECK: firrtl.connect %[[VAL_22]], %[[VAL_16]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<32>>
// CHECK: firrtl.connect %[[VAL_23]], %[[VAL_17]] : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
Expand Down
14 changes: 14 additions & 0 deletions test/Conversion/HandshakeToFIRRTL/test_naming.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: circt-opt -lower-handshake-to-firrtl %s | FileCheck %s

// CHECK-LABEL: firrtl.module @main(
// CHECK: in %a: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>,
// CHECK: in %b: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>,
// CHECK: in %ctrl: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK: out %arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>,
// CHECK: out %arg4: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK: in %clock: !firrtl.clock,
// CHECK: in %reset: !firrtl.uint<1>) {
handshake.func @main(%a: index, %b: index, %ctrl: none, ...) -> (index, none) {
%0 = arith.addi %a, %b : index
handshake.return %0, %ctrl : index, none
}
4 changes: 2 additions & 2 deletions test/Conversion/HandshakeToFIRRTL/test_sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

// CHECK-LABEL: firrtl.module @test_sink(
// CHECK-SAME: in %arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, in %arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, out %arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) {
handshake.func @test_sink(%arg0: index, %arg2: none, ...) -> (none) {
handshake.func @test_sink(%arg0: index, %arg1: none, ...) -> (none) {

// CHECK: %inst_arg0 = firrtl.instance "" @handshake_sink_in_ui64(in arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>)
// CHECK: firrtl.connect %inst_arg0, %arg0 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>
"handshake.sink"(%arg0) : (index) -> ()

// CHECK: firrtl.connect %arg2, %arg1 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
handshake.return %arg2 : none
handshake.return %arg1 : none
}
2 changes: 1 addition & 1 deletion test/Conversion/StandardToHandshake/test_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module {

// CHECK-LABEL: handshake.func @simple(
// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> none {
// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> none attributes {argNames = ["arg0"]} {
handshake.func @simple(%arg0: none, ...) -> none {

// CHECK: %[[VAL_1:.*]] = "handshake.constant"(%[[VAL_0:.*]]) {value = 1 : index} : (none) -> index
Expand Down
6 changes: 6 additions & 0 deletions test/Conversion/StandardToHandshake/test_naming.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: circt-opt -lower-std-to-handshake %s | FileCheck %s

// CHECK-LABEL: handshake.func @main(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: none, ...) -> none attributes {argNames = ["a", "b", "c"]} {
func @main(%arg0 : i32, %b : i32, %c: i32) attributes {argNames = ["a", "b", "c"]} {
return
}

0 comments on commit 6cb07bf

Please sign in to comment.