From 0243177a750478f4b5f446ee72bcde933025300c Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Wed, 1 Feb 2023 03:37:21 -0800 Subject: [PATCH] Revert "Revert "[LowerToHW] Lower aggregate constant (#4451)" (#4465)" This reverts commit b4e042dd2ae9c718a09ad228785607d72d6dc1b6. --- lib/Conversion/FIRRTLToHW/LowerToHW.cpp | 56 +++++++++++++++------ test/Conversion/FIRRTLToHW/lower-to-hw.mlir | 12 +++-- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/FIRRTLToHW/LowerToHW.cpp b/lib/Conversion/FIRRTLToHW/LowerToHW.cpp index 01a3fa7e1400..8bc68a229653 100644 --- a/lib/Conversion/FIRRTLToHW/LowerToHW.cpp +++ b/lib/Conversion/FIRRTLToHW/LowerToHW.cpp @@ -1441,6 +1441,7 @@ struct FIRRTLLowering : public FIRRTLVisitor { bool isSigned = false) { return getOrCreateIntConstant(APInt(numBits, val, isSigned)); } + Attribute getOrCreateAggregateConstantAttribute(Attribute value, Type type); Value getOrCreateXConstant(unsigned numBits); Value getPossiblyInoutLoweredValue(Value value); Value getLoweredValue(Value value); @@ -1653,6 +1654,7 @@ struct FIRRTLLowering : public FIRRTLVisitor { /// This keeps track of constants that we have created so we can reuse them. /// This is populated by the getOrCreateIntConstant method. DenseMap hwConstantMap; + DenseMap, Attribute> hwAggregateConstantMap; /// This keeps track of constant X that we have created so we can reuse them. /// This is populated by the getOrCreateXConstant method. @@ -1842,6 +1844,41 @@ Value FIRRTLLowering::getOrCreateIntConstant(const APInt &value) { return entry; } +/// Check to see if we've already created the specified aggregate constant +/// attribute. If so, return it. Otherwise create it. +Attribute FIRRTLLowering::getOrCreateAggregateConstantAttribute(Attribute value, + Type type) { + // Base case. + if (hw::type_isa(type)) + return value; + + auto cache = hwAggregateConstantMap.lookup({value, type}); + if (cache) + return cache; + + // Recursively construct elements. + SmallVector values; + for (auto &e : llvm::enumerate(value.cast())) { + Type subType; + if (auto array = hw::type_dyn_cast(type)) + subType = array.getElementType(); + else if (auto structType = hw::type_dyn_cast(type)) + subType = structType.getElements()[e.index()].type; + else + assert(false && "type must be either array or struct"); + + values.push_back(getOrCreateAggregateConstantAttribute(e.value(), subType)); + } + + // FIRRTL and HW have a different operand ordering for arrays. + if (hw::type_isa(type)) + std::reverse(values.begin(), values.end()); + + auto &entry = hwAggregateConstantMap[{value, type}]; + entry = builder.getArrayAttr(values); + return entry; +} + /// Zero bit operands end up looking like failures from getLoweredValue. This /// helper function invokes the closure specified if the operand was actually /// zero bit, or returns failure() if it was some other kind of failure. @@ -2551,22 +2588,11 @@ LogicalResult FIRRTLLowering::visitExpr(BundleCreateOp op) { LogicalResult FIRRTLLowering::visitExpr(AggregateConstantOp op) { auto resultType = lowerType(op.getResult().getType()); - auto vec = op.getType().dyn_cast(); - // Currently we only support 1d vector types. - if (!vec || !vec.getElementType().isa()) { - op.emitError() - << "has an unsupported type; currently we only support 1d vectors"; - return failure(); - } + auto attr = + getOrCreateAggregateConstantAttribute(op.getFieldsAttr(), resultType); - // TODO: Use hw aggregate constant - SmallVector operands; - // Make sure to reverse the operands. - for (auto elem : llvm::reverse(op.getFields())) - operands.push_back( - getOrCreateIntConstant(elem.cast().getValue())); - - return setLoweringTo(op, resultType, operands); + return setLoweringTo(op, resultType, + attr.cast()); } //===----------------------------------------------------------------------===// diff --git a/test/Conversion/FIRRTLToHW/lower-to-hw.mlir b/test/Conversion/FIRRTLToHW/lower-to-hw.mlir index 0eac57895a26..ba9a32ab09bc 100644 --- a/test/Conversion/FIRRTLToHW/lower-to-hw.mlir +++ b/test/Conversion/FIRRTLToHW/lower-to-hw.mlir @@ -1788,11 +1788,13 @@ firrtl.circuit "Simple" attributes {annotations = [{class = } // CHECK-LABEL: hw.module @aggregateconstant - firrtl.module @aggregateconstant(out %out : !firrtl.vector, 2>) { - %0 = firrtl.aggregateconstant [1 : ui8, 0: ui8] : !firrtl.vector, 2> - firrtl.strictconnect %out, %0 : !firrtl.vector, 2> - // CHECK: %0 = hw.aggregate_constant [0 : i8, 1 : i8] : !hw.array<2xi8> - // CHECK-NEXT: hw.output %0 : !hw.array<2xi8> + firrtl.module @aggregateconstant(out %out : !firrtl.bundle, 2>, 2>, b: vector, 2>, 2>>) { + %0 = firrtl.aggregateconstant [[[0 : ui8, 1: ui8], [2 : ui8, 3: ui8]], [[4: ui8, 5: ui8], [6: ui8, 7:ui8]]] : + !firrtl.bundle, 2>, 2>, b: vector, 2>, 2>> + firrtl.strictconnect %out, %0 : !firrtl.bundle, 2>, 2>, b: vector, 2>, 2>> + // CHECK-LITERAL: %0 = hw.aggregate_constant [[[3 : ui8, 2 : ui8], [1 : ui8, 0 : ui8]], [[7 : ui8, 6 : ui8], [5 : ui8, 4 : ui8]]] + // CHECK-SAME: !hw.struct>, b: !hw.array<2xarray<2xi8>>> + // CHECK: hw.output %0 } // CHECK-LABEL: hw.module @intrinsic