Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HandshakeToDC] Add pack/unpack lowering patterns #6941

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[DC] Add data-only pack/unpack operations
Adds data-only pack/unpack operations which operates on MLIR tuple-typed values. Don't want these in DC, but there really isn't any other place to put them; and they are required for an end-to-end handshake lowering path.

Add tuple packing dc-to-hw lowering

remove DC stuff, add tuple -> Hw conversion in Handshake to DC
  • Loading branch information
mortbopet authored and teqdruid committed Nov 13, 2024
commit 020798b24415a1e6fb1b019ad095cf691e53460f
93 changes: 84 additions & 9 deletions lib/Conversion/HandshakeToDC/HandshakeToDC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,33 @@ static Value pack(OpBuilder &b, Value token, Value data = {}) {
return b.create<dc::PackOp>(token.getLoc(), token, data);
}

// NOLINTNEXTLINE(misc-no-recursion)
static Type tupleToStruct(TupleType tuple) {
auto *ctx = tuple.getContext();
mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
for (auto [i, innerType] : llvm::enumerate(tuple)) {
Type convertedInnerType = innerType;
if (auto tupleInnerType = innerType.dyn_cast<TupleType>())
convertedInnerType = tupleToStruct(tupleInnerType);
hwfields.push_back(
{StringAttr::get(ctx, "field" + Twine(i)), convertedInnerType});
}

return hw::StructType::get(ctx, hwfields);
}

class DCTypeConverter : public TypeConverter {
public:
DCTypeConverter() {
addConversion([](Type type) -> Type {
if (isa<NoneType>(type))
return dc::TokenType::get(type.getContext());

// For pragmatic reasons, we use a struct type to represent tuples in the
// DC lowering; upstream MLIR doesn't have builtin type-modifying ops,
// so the next best thing is our "local" struct type in CIRCT.
if (auto tupleType = type.dyn_cast<TupleType>())
return dc::ValueType::get(type.getContext(), tupleToStruct(tupleType));
return dc::ValueType::get(type.getContext(), type);
});
addConversion([](ValueType type) { return type; });
Expand Down Expand Up @@ -249,6 +270,59 @@ class MergeOpConversion : public DCOpConversionPattern<handshake::MergeOp> {
}
};

class PackOpConversion : public DCOpConversionPattern<handshake::PackOp> {
public:
using DCOpConversionPattern<handshake::PackOp>::DCOpConversionPattern;
using OpAdaptor = typename handshake::PackOp::Adaptor;

LogicalResult
matchAndRewrite(handshake::PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Like the join conversion, but also emits a dc.pack_tuple operation to
// handle the data side of the operation (since there's no upstream support
// for doing so, sigh...)
llvm::SmallVector<Value, 4> inputTokens, inputData;
for (auto input : adaptor.getOperands()) {
auto dct = unpack(rewriter, input);
inputTokens.push_back(dct.token);
if (dct.data)
inputData.push_back(dct.data);
}

auto join = rewriter.create<dc::JoinOp>(op.getLoc(), inputTokens);
auto structType = tupleToStruct(op.getResult().getType().cast<TupleType>());
auto packedData =
rewriter.create<hw::StructCreateOp>(op.getLoc(), structType, inputData);
convertedOps->insert(packedData);
rewriter.replaceOp(op, pack(rewriter, join, packedData));
return success();
}
};

class UnpackOpConversion : public DCOpConversionPattern<handshake::UnpackOp> {
public:
using DCOpConversionPattern<handshake::UnpackOp>::DCOpConversionPattern;
using OpAdaptor = typename handshake::UnpackOp::Adaptor;

LogicalResult
matchAndRewrite(handshake::UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Unpack the !dc.value<tuple<...>> into the !dc.token and tuple<...>
// values.
DCTuple unpackedInput = unpack(rewriter, adaptor.getInput());
auto unpackedData =
rewriter.create<hw::StructExplodeOp>(op.getLoc(), unpackedInput.data);
convertedOps->insert(unpackedData);
// Re-pack each of the tuple elements with the token.
llvm::SmallVector<Value, 4> repackedInputs;
for (auto outputData : unpackedData.getResults())
repackedInputs.push_back(pack(rewriter, unpackedInput.token, outputData));

rewriter.replaceOp(op, repackedInputs);
return success();
}
};

class ControlMergeOpConversion
: public DCOpConversionPattern<handshake::ControlMergeOp> {
public:
Expand Down Expand Up @@ -584,9 +658,9 @@ class FuncOpConversion : public OpConversionPattern<handshake::FuncOp> {
// Replaces a handshake.func with a hw.module, converting the argument and
// result types using the provided type converter.
// @mortbopet: Not a fan of converting to hw here seeing as we don't
// necessarily have hardware semantics here. But, DC doesn't define a function
// operation, and there is no "func.graph_func" or any other generic function
// operation which is a graph region...
// necessarily have hardware semantics here. But, DC doesn't define a
// function operation, and there is no "func.graph_func" or any other
// generic function operation which is a graph region...
LogicalResult
matchAndRewrite(handshake::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -685,12 +759,13 @@ LogicalResult circt::handshaketodc::runHandshakeToDC(
// Add handshake conversion patterns.
// Note: merge/control merge are not supported - these are non-deterministic
// operators and we do not care for them.
patterns.add<BufferOpConversion, CondBranchConversionPattern,
SinkOpConversionPattern, SourceOpConversionPattern,
MuxOpConversionPattern, ForkOpConversionPattern,
JoinOpConversion, MergeOpConversion, ControlMergeOpConversion,
ConstantOpConversion, SyncOpConversion>(ctx, typeConverter,
&convertedOps);
patterns
.add<BufferOpConversion, CondBranchConversionPattern,
SinkOpConversionPattern, SourceOpConversionPattern,
MuxOpConversionPattern, ForkOpConversionPattern, JoinOpConversion,
PackOpConversion, UnpackOpConversion, MergeOpConversion,
ControlMergeOpConversion, ConstantOpConversion, SyncOpConversion>(
ctx, typeConverter, &convertedOps);

// ALL other single-result operations are converted via the
// UnitRateConversionPattern.
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/HandshakeToDC/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,22 @@ handshake.func @nonemerge(%arg0 : none, %arg1 : none) -> none {
%out = merge %arg0, %arg1 : none
return %out : none
}

// CHECK-LABEL: hw.module @pack_unpack(in
// CHECK-SAME: %[[VAL_0:.*]] : !dc.value<i32>, in %[[VAL_1:.*]] : !dc.value<i1>, out out0 : !dc.value<i32>, out out1 : !dc.value<i1>) {
// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i32>
// CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = dc.unpack %[[VAL_1]] : !dc.value<i1>
// CHECK: %[[VAL_6:.*]] = dc.join %[[VAL_2]], %[[VAL_4]]
// CHECK: %[[VAL_7:.*]] = hw.struct_create (%[[VAL_3]], %[[VAL_5]]) : !hw.struct<field0: i32, field1: i1>
// CHECK: %[[VAL_8:.*]] = dc.pack %[[VAL_6]], %[[VAL_7]] : !hw.struct<field0: i32, field1: i1>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = dc.unpack %[[VAL_8]] : !dc.value<!hw.struct<field0: i32, field1: i1>>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]] = hw.struct_explode %[[VAL_10]] : !hw.struct<field0: i32, field1: i1>
// CHECK: %[[VAL_13:.*]] = dc.pack %[[VAL_9]], %[[VAL_11]] : i32
// CHECK: %[[VAL_14:.*]] = dc.pack %[[VAL_9]], %[[VAL_12]] : i1
// CHECK: hw.output %[[VAL_13]], %[[VAL_14]] : !dc.value<i32>, !dc.value<i1>
// CHECK: }
handshake.func @pack_unpack(%arg0 : i32, %arg1 : i1) -> (i32, i1) {
%packed = handshake.pack %arg0, %arg1 : tuple<i32, i1>
%a, %b = handshake.unpack %packed : tuple<i32, i1>
return %a, %b : i32, i1
}