Skip to content

Commit

Permalink
Revert "Revert "[LowerToHW] Lower aggregate constant (#4451)" (#4465)"
Browse files Browse the repository at this point in the history
This reverts commit b4e042d.
  • Loading branch information
uenoku committed Feb 1, 2023
1 parent aabaffa commit 0243177
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
56 changes: 41 additions & 15 deletions lib/Conversion/FIRRTLToHW/LowerToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,7 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
bool isSigned = false) {
return getOrCreateIntConstant(APInt(numBits, val, isSigned));
}
Attribute getOrCreateAggregateConstantAttribute(Attribute value, Type type);
Value getOrCreateXConstant(unsigned numBits);
Value getPossiblyInoutLoweredValue(Value value);
Value getLoweredValue(Value value);
Expand Down Expand Up @@ -1653,6 +1654,7 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
/// This keeps track of constants that we have created so we can reuse them.
/// This is populated by the getOrCreateIntConstant method.
DenseMap<Attribute, Value> hwConstantMap;
DenseMap<std::pair<Attribute, Type>, Attribute> hwAggregateConstantMap;

/// This keeps track of constant X that we have created so we can reuse them.
/// This is populated by the getOrCreateXConstant method.
Expand Down Expand Up @@ -1842,6 +1844,41 @@ Value FIRRTLLowering::getOrCreateIntConstant(const APInt &value) {
return entry;
}

/// Check to see if we've already created the specified aggregate constant
/// attribute. If so, return it. Otherwise create it.
Attribute FIRRTLLowering::getOrCreateAggregateConstantAttribute(Attribute value,
Type type) {
// Base case.
if (hw::type_isa<IntegerType>(type))
return value;

auto cache = hwAggregateConstantMap.lookup({value, type});
if (cache)
return cache;

// Recursively construct elements.
SmallVector<Attribute> values;
for (auto &e : llvm::enumerate(value.cast<ArrayAttr>())) {
Type subType;
if (auto array = hw::type_dyn_cast<hw::ArrayType>(type))
subType = array.getElementType();
else if (auto structType = hw::type_dyn_cast<hw::StructType>(type))
subType = structType.getElements()[e.index()].type;
else
assert(false && "type must be either array or struct");

values.push_back(getOrCreateAggregateConstantAttribute(e.value(), subType));
}

// FIRRTL and HW have a different operand ordering for arrays.
if (hw::type_isa<hw::ArrayType>(type))
std::reverse(values.begin(), values.end());

auto &entry = hwAggregateConstantMap[{value, type}];
entry = builder.getArrayAttr(values);
return entry;
}

/// Zero bit operands end up looking like failures from getLoweredValue. This
/// helper function invokes the closure specified if the operand was actually
/// zero bit, or returns failure() if it was some other kind of failure.
Expand Down Expand Up @@ -2551,22 +2588,11 @@ LogicalResult FIRRTLLowering::visitExpr(BundleCreateOp op) {

LogicalResult FIRRTLLowering::visitExpr(AggregateConstantOp op) {
auto resultType = lowerType(op.getResult().getType());
auto vec = op.getType().dyn_cast<FVectorType>();
// Currently we only support 1d vector types.
if (!vec || !vec.getElementType().isa<IntType>()) {
op.emitError()
<< "has an unsupported type; currently we only support 1d vectors";
return failure();
}
auto attr =
getOrCreateAggregateConstantAttribute(op.getFieldsAttr(), resultType);

// TODO: Use hw aggregate constant
SmallVector<Value> operands;
// Make sure to reverse the operands.
for (auto elem : llvm::reverse(op.getFields()))
operands.push_back(
getOrCreateIntConstant(elem.cast<IntegerAttr>().getValue()));

return setLoweringTo<hw::ArrayCreateOp>(op, resultType, operands);
return setLoweringTo<hw::AggregateConstantOp>(op, resultType,
attr.cast<ArrayAttr>());
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 7 additions & 5 deletions test/Conversion/FIRRTLToHW/lower-to-hw.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1788,11 +1788,13 @@ firrtl.circuit "Simple" attributes {annotations = [{class =
}

// CHECK-LABEL: hw.module @aggregateconstant
firrtl.module @aggregateconstant(out %out : !firrtl.vector<uint<8>, 2>) {
%0 = firrtl.aggregateconstant [1 : ui8, 0: ui8] : !firrtl.vector<uint<8>, 2>
firrtl.strictconnect %out, %0 : !firrtl.vector<uint<8>, 2>
// CHECK: %0 = hw.aggregate_constant [0 : i8, 1 : i8] : !hw.array<2xi8>
// CHECK-NEXT: hw.output %0 : !hw.array<2xi8>
firrtl.module @aggregateconstant(out %out : !firrtl.bundle<a: vector<vector<uint<8>, 2>, 2>, b: vector<vector<uint<8>, 2>, 2>>) {
%0 = firrtl.aggregateconstant [[[0 : ui8, 1: ui8], [2 : ui8, 3: ui8]], [[4: ui8, 5: ui8], [6: ui8, 7:ui8]]] :
!firrtl.bundle<a: vector<vector<uint<8>, 2>, 2>, b: vector<vector<uint<8>, 2>, 2>>
firrtl.strictconnect %out, %0 : !firrtl.bundle<a: vector<vector<uint<8>, 2>, 2>, b: vector<vector<uint<8>, 2>, 2>>
// CHECK-LITERAL: %0 = hw.aggregate_constant [[[3 : ui8, 2 : ui8], [1 : ui8, 0 : ui8]], [[7 : ui8, 6 : ui8], [5 : ui8, 4 : ui8]]]
// CHECK-SAME: !hw.struct<a: !hw.array<2xarray<2xi8>>, b: !hw.array<2xarray<2xi8>>>
// CHECK: hw.output %0
}

// CHECK-LABEL: hw.module @intrinsic
Expand Down

0 comments on commit 0243177

Please sign in to comment.