Skip to content

Commit

Permalink
[HW][FlattenIO] Fix extern module instances (#6441)
Browse files Browse the repository at this point in the history
There was a bug in FlattenIO, it was crashing for instances of hwModuleExternOp.
If an instance op is flattened before the corresponding module op, then it results in an inconsistent IR.
Fix this issue, by using an explicit InstanceOp builder that doesn't rely on the module op.
Add a flag to control the external module flattening and also enable recursive by default.
  • Loading branch information
prithayan authored Nov 28, 2023
1 parent 36b7c56 commit 02f7a50
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 70 deletions.
3 changes: 2 additions & 1 deletion include/circt/Dialect/HW/HWPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace hw {
std::unique_ptr<mlir::Pass> createPrintInstanceGraphPass();
std::unique_ptr<mlir::Pass> createHWSpecializePass();
std::unique_ptr<mlir::Pass> createPrintHWModuleGraphPass();
std::unique_ptr<mlir::Pass> createFlattenIOPass();
std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
bool flattenExternFlag = false);
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();

/// Generate the code for registering passes.
Expand Down
4 changes: 3 additions & 1 deletion include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def FlattenIO : Pass<"hw-flatten-io", "mlir::ModuleOp"> {
let constructor = "circt::hw::createFlattenIOPass()";

let options = [
Option<"recursive", "recursive", "bool", "false",
Option<"recursive", "recursive", "bool", "true",
"Recursively flatten nested structs.">,
Option<"flattenExtern", "flatten-extern", "bool", "false",
"Flatten the extern modules also.">,
];
}

Expand Down
128 changes: 95 additions & 33 deletions lib/Dialect/HW/Transforms/FlattenIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,20 @@ struct OutputOpConversion : public OpConversionPattern<hw::OutputOp> {

struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
DenseSet<hw::InstanceOp> *convertedOps)
: OpConversionPattern(typeConverter, context),
convertedOps(convertedOps) {}
DenseSet<hw::InstanceOp> *convertedOps,
const StringSet<> *externModules)
: OpConversionPattern(typeConverter, context), convertedOps(convertedOps),
externModules(externModules) {}

LogicalResult
matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto referencedMod = op.getReferencedModuleNameAttr();
// If externModules is populated and this is an extern module instance,
// donot flatten it.
if (externModules->contains(referencedMod.getValue()))
return success();

auto loc = op.getLoc();
// Flatten the operands.
llvm::SmallVector<Value> convOperands;
Expand All @@ -93,11 +100,23 @@ struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
}
}

// Create the new instance...
Operation *targetModule = SymbolTable::lookupNearestSymbolFrom(
op, op.getReferencedModuleNameAttr());
// Get the new module return type.
llvm::SmallVector<Type> newResultTypes;
for (auto oldResultType : op.getResultTypes()) {
if (auto structType = getStructType(oldResultType))
for (auto t : structType.getElements())
newResultTypes.push_back(t.type);
else
newResultTypes.push_back(oldResultType);
}

// Create the new instance with the flattened module, attributes will be
// adjusted later.
auto newInstance = rewriter.create<hw::InstanceOp>(
loc, targetModule, op.getInstanceName(), convOperands);
loc, newResultTypes, op.getInstanceNameAttr(),
FlatSymbolRefAttr::get(referencedMod), convOperands,
op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
op.getInnerSymAttr());

// re-create any structs in the result.
llvm::SmallVector<Value> convResults;
Expand All @@ -123,6 +142,7 @@ struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
}

DenseSet<hw::InstanceOp> *convertedOps;
const StringSet<> *externModules;
};

using IOTypes = std::pair<TypeRange, TypeRange>;
Expand All @@ -144,6 +164,7 @@ class FlattenIOTypeConverter : public TypeConverter {
results.push_back(type);
else {
for (auto field : structType.getElements())

results.push_back(field.type);
}
return success();
Expand Down Expand Up @@ -290,28 +311,33 @@ updateBlockLocations(hw::HWModuleLike op,
arg.setLoc(loc);
}

static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
ioInfo.argTypes = op.getInputTypes();
ioInfo.resTypes = op.getOutputTypes();
for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
if (auto structType = getStructType(arg))
ioInfo.argStructs[i] = structType;
}
for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
if (auto structType = getStructType(res))
ioInfo.resStructs[i] = structType;
}
}

template <typename T>
static DenseMap<Operation *, IOInfo> populateIOInfoMap(mlir::ModuleOp module) {
DenseMap<Operation *, IOInfo> ioInfoMap;
for (auto op : module.getOps<T>()) {
IOInfo ioInfo;
ioInfo.argTypes = op.getInputTypes();
ioInfo.resTypes = op.getOutputTypes();
for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
if (auto structType = getStructType(arg))
ioInfo.argStructs[i] = structType;
}
for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
if (auto structType = getStructType(res))
ioInfo.resStructs[i] = structType;
}
setIOInfo(op, ioInfo);
ioInfoMap[op] = ioInfo;
}
return ioInfoMap;
}

template <typename T>
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive,
StringSet<> &externModules) {
auto *ctx = module.getContext();
FlattenIOTypeConverter typeConverter;

Expand All @@ -338,13 +364,16 @@ static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {
DenseSet<Operation *> opVisited;
patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);

patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances);
patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
&externModules);
target.addDynamicallyLegalOp<hw::OutputOp>(
[&](auto op) { return opVisited.contains(op->getParentOp()); });
target.addDynamicallyLegalOp<hw::InstanceOp>([&](auto op) {
return llvm::none_of(op->getOperands(), [](auto operand) {
return isStructType(operand.getType());
});
target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
auto refName = op.getReferencedModuleName();
return externModules.contains(refName) ||
llvm::none_of(op->getOperands(), [](auto operand) {
return isStructType(operand.getType());
});
});

DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames, oldArgLocs,
Expand Down Expand Up @@ -383,10 +412,24 @@ static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {

// And likewise with the converted instance ops.
for (auto instanceOp : convertedInstances) {
Operation *targetModule = SymbolTable::lookupNearestSymbolFrom(
instanceOp, instanceOp.getReferencedModuleNameAttr());
auto targetModule =
cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
instanceOp, instanceOp.getReferencedModuleNameAttr()));

IOInfo ioInfo;
if (!ioInfoMap.contains(targetModule)) {
// If an extern module, then not yet processed, populate the maps.
setIOInfo(targetModule, ioInfo);
ioInfoMap[targetModule] = ioInfo;
oldArgNames[targetModule] =
ArrayAttr::get(module.getContext(), targetModule.getInputNames());
oldResNames[targetModule] =
ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
oldArgLocs[targetModule] = targetModule.getInputLocsAttr();
oldResLocs[targetModule] = targetModule.getOutputLocsAttr();
} else
ioInfo = ioInfoMap[targetModule];

auto ioInfo = ioInfoMap[targetModule];
instanceOp.setInputNames(ArrayAttr::get(
instanceOp.getContext(),
updateNameAttribute(instanceOp, "argNames", ioInfo.argStructs,
Expand All @@ -397,7 +440,6 @@ static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {
updateNameAttribute(instanceOp, "resultNames", ioInfo.resStructs,
oldResNames[targetModule]
.template getAsValueRange<StringAttr>())));
instanceOp.dump();
}

// Break if we've only lowering a single level of structs.
Expand All @@ -412,28 +454,48 @@ static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {
//===----------------------------------------------------------------------===//

template <typename... TOps>
static bool flattenIO(ModuleOp module, bool recursive) {
return (failed(flattenOpsOfType<TOps>(module, recursive)) || ...);
static bool flattenIO(ModuleOp module, bool recursive,
StringSet<> &externModules) {
return (failed(flattenOpsOfType<TOps>(module, recursive, externModules)) ||
...);
}

namespace {

class FlattenIOPass : public circt::hw::FlattenIOBase<FlattenIOPass> {
public:
FlattenIOPass(bool recursiveFlag, bool flattenExternFlag) {
recursive = recursiveFlag;
flattenExtern = flattenExternFlag;
}

void runOnOperation() override {
ModuleOp module = getOperation();
if (!flattenExtern) {
// Record the extern modules, donot flatten them.
for (auto m : module.getOps<hw::HWModuleExternOp>())
externModules.insert(m.getModuleName());
if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(module, recursive,
externModules))
signalPassFailure();
return;
}

if (flattenIO<hw::HWModuleOp, hw::HWModuleExternOp,
hw::HWModuleGeneratedOp>(module, recursive))
hw::HWModuleGeneratedOp>(module, recursive, externModules))
signalPassFailure();
};
};

private:
StringSet<> externModules;
};
} // namespace

//===----------------------------------------------------------------------===//
// Pass initialization
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> circt::hw::createFlattenIOPass() {
return std::make_unique<FlattenIOPass>();
std::unique_ptr<Pass> circt::hw::createFlattenIOPass(bool recursiveFlag,
bool flattenExternFlag) {
return std::make_unique<FlattenIOPass>(true, flattenExternFlag);
}
84 changes: 49 additions & 35 deletions test/Dialect/HW/flatten-io.mlir
Original file line number Diff line number Diff line change
@@ -1,66 +1,80 @@
// RUN: circt-opt --hw-flatten-io="recursive=true" %s | FileCheck %s
// RUN: circt-opt --hw-flatten-io %s | FileCheck %s -check-prefix BASIC
// RUN: circt-opt --hw-flatten-io="flatten-extern=true" %s | FileCheck %s -check-prefix EXTERN

// Ensure that non-struct-using modules pass cleanly through the pass.

// CHECK-LABEL: hw.module @level0(in %arg0 : i32, out out0 : i32) {
// CHECK-NEXT: hw.output %arg0 : i32
// CHECK-NEXT: }
// BASIC-LABEL: hw.module @level0(in %arg0 : i32, out out0 : i32) {
// BASIC-NEXT: hw.output %arg0 : i32
// BASIC-NEXT: }
hw.module @level0(in %arg0 : i32, out out0 : i32) {
hw.output %arg0: i32
}

// CHECK-LABEL: hw.module @level1(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32) {
// CHECK-NEXT: %0 = hw.struct_create (%in.a, %in.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: hw.output %arg0, %a, %b, %arg1 : i32, i1, i2, i32
// CHECK-NEXT: }
// BASIC-LABEL: hw.module @level1(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32) {
// BASIC-NEXT: %0 = hw.struct_create (%in.a, %in.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: hw.output %arg0, %a, %b, %arg1 : i32, i1, i2, i32
// BASIC-NEXT: }
!Struct1 = !hw.struct<a: i1, b: i2>
hw.module @level1(in %arg0 : i32, in %in : !Struct1, in %arg1: i32, out out0 : i32, out out: !Struct1, out out1: i32) {
hw.output %arg0, %in, %arg1 : i32, !Struct1, i32
}

// CHECK-LABEL: hw.module @level2(in %in.aa.a : i1, in %in.aa.b : i2, in %in.bb.a : i1, in %in.bb.b : i2, out out.aa.a : i1, out out.aa.b : i2, out out.bb.a : i1, out out.bb.b : i2) {
// CHECK-NEXT: %0 = hw.struct_create (%in.aa.a, %in.aa.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %1 = hw.struct_create (%in.bb.a, %in.bb.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %2 = hw.struct_create (%0, %1) : !hw.struct<aa: !hw.struct<a: i1, b: i2>, bb: !hw.struct<a: i1, b: i2>>
// CHECK-NEXT: %aa, %bb = hw.struct_explode %2 : !hw.struct<aa: !hw.struct<a: i1, b: i2>, bb: !hw.struct<a: i1, b: i2>>
// CHECK-NEXT: %a, %b = hw.struct_explode %aa : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %a_0, %b_1 = hw.struct_explode %bb : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: hw.output %a, %b, %a_0, %b_1 : i1, i2, i1, i2
// CHECK-NEXT: }
// BASIC-LABEL: hw.module @level2(in %in.aa.a : i1, in %in.aa.b : i2, in %in.bb.a : i1, in %in.bb.b : i2, out out.aa.a : i1, out out.aa.b : i2, out out.bb.a : i1, out out.bb.b : i2) {
// BASIC-NEXT: %0 = hw.struct_create (%in.aa.a, %in.aa.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %1 = hw.struct_create (%in.bb.a, %in.bb.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %2 = hw.struct_create (%0, %1) : !hw.struct<aa: !hw.struct<a: i1, b: i2>, bb: !hw.struct<a: i1, b: i2>>
// BASIC-NEXT: %aa, %bb = hw.struct_explode %2 : !hw.struct<aa: !hw.struct<a: i1, b: i2>, bb: !hw.struct<a: i1, b: i2>>
// BASIC-NEXT: %a, %b = hw.struct_explode %aa : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %a_0, %b_1 = hw.struct_explode %bb : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: hw.output %a, %b, %a_0, %b_1 : i1, i2, i1, i2
// BASIC-NEXT: }
!Struct2 = !hw.struct<aa: !Struct1, bb: !Struct1>
hw.module @level2(in %in : !Struct2, out out: !Struct2) {
hw.output %in : !Struct2
}

// CHECK-LABEL: hw.module.extern @level1_extern(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32)
hw.module.extern @level1_extern(in %arg0 : i32, in %in : !Struct1, in %arg1: i32, out out0 : i32, out out: !Struct1, out out1: i32)


hw.type_scope @foo {
hw.typedecl @bar : !Struct1
}
!ScopedStruct = !hw.typealias<@foo::@bar,!Struct1>

// CHECK-LABEL: hw.module @scoped(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32) {
// CHECK-NEXT: %0 = hw.struct_create (%in.a, %in.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: hw.output %arg0, %a, %b, %arg1 : i32, i1, i2, i32
// CHECK-NEXT: }
// BASIC-LABEL: hw.module @scoped(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32) {
// BASIC-NEXT: %0 = hw.struct_create (%in.a, %in.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: hw.output %arg0, %a, %b, %arg1 : i32, i1, i2, i32
// BASIC-NEXT: }
hw.module @scoped(in %arg0 : i32, in %in : !ScopedStruct, in %arg1: i32, out out0 : i32, out out: !ScopedStruct, out out1: i32) {
hw.output %arg0, %in, %arg1 : i32, !ScopedStruct, i32
}

// CHECK-LABEL: hw.module @instance(in %arg0 : i32, in %arg1.a : i1, in %arg1.b : i2, out out.a : i1, out out.b : i2) {
// CHECK-NEXT: %0 = hw.struct_create (%arg1.a, %arg1.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %l1.out0, %l1.out.a, %l1.out.b, %l1.out1 = hw.instance "l1" @level1(arg0: %arg0: i32, in.a: %a: i1, in.b: %b: i2, arg1: %arg0: i32) -> (out0: i32, out.a: i1, out.b: i2, out1: i32)
// CHECK-NEXT: %1 = hw.struct_create (%l1.out.a, %l1.out.b) : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: %a_0, %b_1 = hw.struct_explode %1 : !hw.struct<a: i1, b: i2>
// CHECK-NEXT: hw.output %a_0, %b_1 : i1, i2
// CHECK-NEXT: }
// BASIC-LABEL: hw.module @instance(in %arg0 : i32, in %arg1.a : i1, in %arg1.b : i2, out out.a : i1, out out.b : i2) {
// BASIC-NEXT: %0 = hw.struct_create (%arg1.a, %arg1.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %l1.out0, %l1.out.a, %l1.out.b, %l1.out1 = hw.instance "l1" @level1(arg0: %arg0: i32, in.a: %a: i1, in.b: %b: i2, arg1: %arg0: i32) -> (out0: i32, out.a: i1, out.b: i2, out1: i32)
// BASIC-NEXT: %1 = hw.struct_create (%l1.out.a, %l1.out.b) : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: %a_0, %b_1 = hw.struct_explode %1 : !hw.struct<a: i1, b: i2>
// BASIC-NEXT: hw.output %a_0, %b_1 : i1, i2
// BASIC-NEXT: }
hw.module @instance(in %arg0 : i32, in %arg1 : !Struct1, out out : !Struct1) {
%0:3 = hw.instance "l1" @level1(arg0: %arg0 : i32, in: %arg1 : !Struct1, arg1: %arg0 : i32) -> (out0: i32, out: !Struct1, out1: i32)
hw.output %0#1 : !Struct1
}

// EXTERN-LABEL: hw.module.extern @level1_extern(in %arg0 : i32, in %in.a : i1, in %in.b : i2, in %arg1 : i32, out out0 : i32, out out.a : i1, out out.b : i2, out out1 : i32)
// BASIC-LABEL: hw.module.extern @level1_extern(in %arg0 : i32, in %in : !hw.struct<a: i1, b: i2>, in %arg1 : i32, out out0 : i32, out out : !hw.struct<a: i1, b: i2>, out out1 : i32)
hw.module.extern @level1_extern(in %arg0 : i32, in %in : !Struct1, in %arg1: i32, out out0 : i32, out out: !Struct1, out out1: i32)


// EXTERN-LABEL: hw.module @instance_extern(in %arg0 : i32, in %arg1.a : i1, in %arg1.b : i2, out out.a : i1, out out.b : i2) {
// EXTERN-NEXT: %0 = hw.struct_create (%arg1.a, %arg1.b) : !hw.struct<a: i1, b: i2>
// EXTERN-NEXT: %a, %b = hw.struct_explode %0 : !hw.struct<a: i1, b: i2>
// EXTERN-NEXT: %l1.out0, %l1.out.a, %l1.out.b, %l1.out1 = hw.instance "l1" @level1_extern(arg0: %arg0: i32, in.a: %a: i1, in.b: %b: i2, arg1: %arg0: i32) -> (out0: i32, out.a: i1, out.b: i2, out1: i32)
// EXTERN-NEXT: %1 = hw.struct_create (%l1.out.a, %l1.out.b) : !hw.struct<a: i1, b: i2>
// EXTERN-NEXT: %a_0, %b_1 = hw.struct_explode %1 : !hw.struct<a: i1, b: i2>
// EXTERN-NEXT: hw.output %a_0, %b_1 : i1, i2

hw.module @instance_extern(in %arg0 : i32, in %arg1 : !Struct1, out out : !Struct1) {
%0:3 = hw.instance "l1" @level1_extern(arg0: %arg0 : i32, in: %arg1 : !Struct1, arg1: %arg0 : i32) -> (out0: i32, out: !Struct1, out1: i32)
hw.output %0#1 : !Struct1
}

0 comments on commit 02f7a50

Please sign in to comment.