From d15a62d63efa352447a0f6e2f57533a6c64f73cd Mon Sep 17 00:00:00 2001 From: Tynan McAuley Date: Mon, 19 Feb 2024 09:57:08 -0800 Subject: [PATCH] [FIRRTL] Dedup memory wrapper modules in LowerMemory Instead of just dedup-ing the external memory module, we include the memory wrapper module in the dedup calculation. Resolves #6445. --- lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp | 124 ++++++++++-------- test/Dialect/FIRRTL/lower-memory.mlir | 18 +-- 2 files changed, 77 insertions(+), 65 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp b/lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp index 5a6a722aa92f..4f6ec4e71504 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp @@ -93,6 +93,12 @@ FirMemory getSummary(MemOp op) { } namespace { +struct MemOps { + FModuleOp wrapperModule; + FMemModuleOp memModule; + InstanceOp memInst; +}; + struct LowerMemoryPass : public LowerMemoryBase { /// Get the cached namespace for a module. @@ -101,11 +107,11 @@ struct LowerMemoryPass : public LowerMemoryBase { } SmallVector getMemoryModulePorts(const FirMemory &mem); - FMemModuleOp emitMemoryModule(MemOp op, const FirMemory &summary, - const SmallVectorImpl &ports); - FMemModuleOp getOrCreateMemModule(MemOp op, const FirMemory &summary, - const SmallVectorImpl &ports, - bool shouldDedup); + MemOps emitMemoryModule(MemOp op, const FirMemory &mem, + const SmallVectorImpl &ports); + MemOps getOrCreateMemModule(MemOp op, const FirMemory &summary, + const SmallVectorImpl &ports, + bool shouldDedup); FModuleOp createWrapperModule(MemOp op, const FirMemory &summary, bool shouldDedup); InstanceOp emitMemoryInstance(MemOp op, FModuleOp module, @@ -121,7 +127,7 @@ struct LowerMemoryPass : public LowerMemoryBase { /// The set of all memories seen so far. This is used to "deduplicate" /// memories by emitting modules one module for equivalent memories. - std::map memories; + std::map memories; }; } // end anonymous namespace @@ -178,25 +184,61 @@ LowerMemoryPass::getMemoryModulePorts(const FirMemory &mem) { return ports; } -FMemModuleOp +MemOps LowerMemoryPass::emitMemoryModule(MemOp op, const FirMemory &mem, const SmallVectorImpl &ports) { - // Get a non-colliding name for the memory module, and update the summary. - auto newName = circuitNamespace.newName(mem.modName.getValue(), "ext"); - auto moduleName = StringAttr::get(&getContext(), newName); + auto *context = &getContext(); - // Insert the memory module at the bottom of the circuit. + // Get non-colliding names for the memory module and its wrapper. + auto newMemName = circuitNamespace.newName(mem.modName.getValue(), "ext"); + auto memModuleName = StringAttr::get(context, newMemName); + + auto newWrapperName = circuitNamespace.newName(op.getName()); + auto wrapperName = StringAttr::get(context, newWrapperName); + + // Insert the memory module and its wrapper at the bottom of the circuit. auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); + + // Create the wrapper module. + auto wrapper = b.create( + op->getLoc(), wrapperName, + ConventionAttr::get(context, Convention::Internal), ports); + SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private); + + // Create the external memory module. ++numCreatedMemModules; - auto moduleOp = b.create( - mem.loc, moduleName, ports, mem.numReadPorts, mem.numWritePorts, + auto memModule = b.create( + mem.loc, memModuleName, ports, mem.numReadPorts, mem.numWritePorts, mem.numReadWritePorts, mem.dataWidth, mem.maskBits, mem.readLatency, mem.writeLatency, mem.depth); - SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private); - return moduleOp; + SymbolTable::setSymbolVisibility(memModule, SymbolTable::Visibility::Private); + + // Create an instance of the external memory module inside the wrapper + // module. + b.setInsertionPointToStart(wrapper.getBodyBlock()); + auto memInst = + b.create(op->getLoc(), memModule, memModule.getModuleName(), + op.getNameKind(), op.getAnnotations().getValue()); + + // Wire all the ports together. + for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(), + memInst.getResults())) { + if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out) + b.create(op->getLoc(), dst, src); + else + b.create(op->getLoc(), src, dst); + } + + // Remove all NLAs from the instance, we'll fix them up later. + auto nonlocalAttr = StringAttr::get(context, "circt.nonlocal"); + AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool { + return anno.getMember(nonlocalAttr) != nullptr; + }); + + return {wrapper, memModule, memInst}; } -FMemModuleOp +MemOps LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary, const SmallVectorImpl &ports, bool shouldDedup) { @@ -210,14 +252,14 @@ LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary, // Create a new module for this memory. This can update the name recorded in // the memory's summary. - auto module = emitMemoryModule(op, summary, ports); + auto modules = emitMemoryModule(op, summary, ports); // Record the memory module. We don't want to use this module for other // memories, then we don't add it to the table. if (shouldDedup) - memories[summary] = module; + memories[summary] = modules; - return module; + return modules; } void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary, @@ -225,34 +267,10 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary, auto *context = &getContext(); auto ports = getMemoryModulePorts(summary); - // Get a non-colliding name for the memory module, and update the summary. - auto newName = circuitNamespace.newName(mem.getName()); - auto wrapperName = StringAttr::get(&getContext(), newName); - - // Create the wrapper module, inserting it into the bottom of the circuit. - auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); - auto wrapper = b.create( - mem->getLoc(), wrapperName, - ConventionAttr::get(context, Convention::Internal), ports); - SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private); - - // Create an instance of the external memory module. The instance has the - // same name as the target module. - auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup); - b.setInsertionPointToStart(wrapper.getBodyBlock()); - - auto memInst = - b.create(mem->getLoc(), memModule, memModule.getModuleName(), - mem.getNameKind(), mem.getAnnotations().getValue()); - - // Wire all the ports together. - for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(), - memInst.getResults())) { - if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out) - b.create(mem->getLoc(), dst, src); - else - b.create(mem->getLoc(), src, dst); - } + // If they haven't been created yet, generate modules for the external memory + // module and its wrapper module. + auto [wrapper, memModule, memInst] = + getOrCreateMemModule(mem, summary, ports, shouldDedup); // Create an instance of the wrapper memory module, which will replace the // original mem op. @@ -272,11 +290,14 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary, SmallVector newMemModAnnos; OpBuilder nlaBuilder(context); - AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool { - // We're only looking for non-local annotations. + // Copy all NLAs from the original `MemOp` to the new external memory + // instance `memInst`. + AnnotationSet annos(mem); + for (auto anno : annos) { auto nlaSym = anno.getMember(nonlocalAttr); if (!nlaSym) - return false; + continue; + // If we have already seen this NLA, don't re-process it. auto newNLAIter = processedNLAs.find(nlaSym.getAttr()); StringAttr newNLAName; @@ -306,8 +327,7 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary, anno.setMember("circt.nonlocal", FlatSymbolRefAttr::get(newNLAName)); nlaUpdated = true; newMemModAnnos.push_back(anno); - return true; - }); + } if (nlaUpdated) { memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym)); AnnotationSet newAnnos(memInst); diff --git a/test/Dialect/FIRRTL/lower-memory.mlir b/test/Dialect/FIRRTL/lower-memory.mlir index 9dd42d06af10..a168a2af87fe 100644 --- a/test/Dialect/FIRRTL/lower-memory.mlir +++ b/test/Dialect/FIRRTL/lower-memory.mlir @@ -68,15 +68,12 @@ firrtl.module @Dedup() { %mem0_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>> %mem1_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>> // CHECK: firrtl.instance mem0 @mem0( - // CHECK: firrtl.instance mem1 @mem1( + // CHECK: firrtl.instance mem1 @mem0( } // CHECK: firrtl.module private @mem0 // CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext // CHECK: firrtl.memmodule private @mem0_ext - -// CHECK: firrtl.module private @mem1 -// CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext } // Test that memories in the testharness are not deduped with other memories in @@ -240,9 +237,9 @@ firrtl.module @Annotations() attributes {annotations = [{class = "sifive.enterpr // CHECK-LABEL: firrtl.circuit "NonLocalAnnotation" firrtl.circuit "NonLocalAnnotation" { -// CHECK: hw.hierpath private @[[nla_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0] +// CHECK: hw.hierpath private @[[nla0_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0] hw.hierpath private @nla0 [@NonLocalAnnotation::@dut, @DUT::@mem0] -// CHECK: hw.hierpath private @[[nla_1:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem1] +// CHECK: hw.hierpath private @[[nla1_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem0] hw.hierpath private @nla1 [@NonLocalAnnotation::@dut, @DUT] // CHECK: firrtl.module @NonLocalAnnotation() @@ -256,7 +253,7 @@ firrtl.module @DUT() { %mem0_write = firrtl.mem sym @mem0 Undefined {annotations = [{circt.nonlocal = @nla0, class = "test0"}], depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>> // This memory does not have a symbol already attached. - // CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem1 + // CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem0 %mem1_write = firrtl.mem Undefined {annotations = [{circt.nonlocal = @nla1, class = "test1"}], depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>> // LowerMemory should ignore MemOps that are not seqmems. The following memory is a combmem with readLatency=0. @@ -266,12 +263,7 @@ firrtl.module @DUT() { // CHECK: firrtl.module private @mem0 // CHECK: firrtl.instance mem0_ext sym @mem0_ext -// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_0]], class = "test0"}]} +// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla0_0]], class = "test0"}, {circt.nonlocal = @[[nla1_0]], class = "test1"}]} // CHECK-SAME: @mem0_ext( // CHECK: } - -// CHECK: firrtl.module private @mem1 -// CHECK: firrtl.instance mem0_ext sym @mem0_ext -// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_1]], class = "test1"}]} -// CHECK-SAME: @mem0_ext( }