Skip to content

Commit

Permalink
[FIRRTL] Dedup memory wrapper modules in LowerMemory (llvm#6719)
Browse files Browse the repository at this point in the history
Instead of just dedup-ing the external memory module, include the
memory wrapper module in the dedup calculation.

Resolves llvm#6445.
  • Loading branch information
tymcauley authored Mar 11, 2024
1 parent eb5900a commit 2e23cda
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 65 deletions.
124 changes: 72 additions & 52 deletions lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ FirMemory getSummary(MemOp op) {
}

namespace {
struct MemOps {
FModuleOp wrapperModule;
FMemModuleOp memModule;
InstanceOp memInst;
};

struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {

/// Get the cached namespace for a module.
Expand All @@ -101,11 +107,11 @@ struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {
}

SmallVector<PortInfo> getMemoryModulePorts(const FirMemory &mem);
FMemModuleOp emitMemoryModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports);
FMemModuleOp getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup);
MemOps emitMemoryModule(MemOp op, const FirMemory &mem,
const SmallVectorImpl<PortInfo> &ports);
MemOps getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup);
FModuleOp createWrapperModule(MemOp op, const FirMemory &summary,
bool shouldDedup);
InstanceOp emitMemoryInstance(MemOp op, FModuleOp module,
Expand All @@ -121,7 +127,7 @@ struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {

/// The set of all memories seen so far. This is used to "deduplicate"
/// memories by emitting modules one module for equivalent memories.
std::map<FirMemory, FMemModuleOp> memories;
std::map<FirMemory, MemOps> memories;
};
} // end anonymous namespace

Expand Down Expand Up @@ -178,25 +184,61 @@ LowerMemoryPass::getMemoryModulePorts(const FirMemory &mem) {
return ports;
}

FMemModuleOp
MemOps
LowerMemoryPass::emitMemoryModule(MemOp op, const FirMemory &mem,
const SmallVectorImpl<PortInfo> &ports) {
// Get a non-colliding name for the memory module, and update the summary.
auto newName = circuitNamespace.newName(mem.modName.getValue(), "ext");
auto moduleName = StringAttr::get(&getContext(), newName);
auto *context = &getContext();

// Insert the memory module at the bottom of the circuit.
// Get non-colliding names for the memory module and its wrapper.
auto newMemName = circuitNamespace.newName(mem.modName.getValue(), "ext");
auto memModuleName = StringAttr::get(context, newMemName);

auto newWrapperName = circuitNamespace.newName(op.getName());
auto wrapperName = StringAttr::get(context, newWrapperName);

// Insert the memory module and its wrapper at the bottom of the circuit.
auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());

// Create the wrapper module.
auto wrapper = b.create<FModuleOp>(
op->getLoc(), wrapperName,
ConventionAttr::get(context, Convention::Internal), ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Create the external memory module.
++numCreatedMemModules;
auto moduleOp = b.create<FMemModuleOp>(
mem.loc, moduleName, ports, mem.numReadPorts, mem.numWritePorts,
auto memModule = b.create<FMemModuleOp>(
mem.loc, memModuleName, ports, mem.numReadPorts, mem.numWritePorts,
mem.numReadWritePorts, mem.dataWidth, mem.maskBits, mem.readLatency,
mem.writeLatency, mem.depth);
SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
return moduleOp;
SymbolTable::setSymbolVisibility(memModule, SymbolTable::Visibility::Private);

// Create an instance of the external memory module inside the wrapper
// module.
b.setInsertionPointToStart(wrapper.getBodyBlock());
auto memInst =
b.create<InstanceOp>(op->getLoc(), memModule, memModule.getModuleName(),
op.getNameKind(), op.getAnnotations().getValue());

// Wire all the ports together.
for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
memInst.getResults())) {
if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
b.create<StrictConnectOp>(op->getLoc(), dst, src);
else
b.create<StrictConnectOp>(op->getLoc(), src, dst);
}

// Remove all NLAs from the instance, we'll fix them up later.
auto nonlocalAttr = StringAttr::get(context, "circt.nonlocal");
AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
return anno.getMember<FlatSymbolRefAttr>(nonlocalAttr) != nullptr;
});

return {wrapper, memModule, memInst};
}

FMemModuleOp
MemOps
LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup) {
Expand All @@ -210,49 +252,25 @@ LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,

// Create a new module for this memory. This can update the name recorded in
// the memory's summary.
auto module = emitMemoryModule(op, summary, ports);
auto modules = emitMemoryModule(op, summary, ports);

// Record the memory module. We don't want to use this module for other
// memories, then we don't add it to the table.
if (shouldDedup)
memories[summary] = module;
memories[summary] = modules;

return module;
return modules;
}

void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
bool shouldDedup) {
auto *context = &getContext();
auto ports = getMemoryModulePorts(summary);

// Get a non-colliding name for the memory module, and update the summary.
auto newName = circuitNamespace.newName(mem.getName());
auto wrapperName = StringAttr::get(&getContext(), newName);

// Create the wrapper module, inserting it into the bottom of the circuit.
auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
auto wrapper = b.create<FModuleOp>(
mem->getLoc(), wrapperName,
ConventionAttr::get(context, Convention::Internal), ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Create an instance of the external memory module. The instance has the
// same name as the target module.
auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
b.setInsertionPointToStart(wrapper.getBodyBlock());

auto memInst =
b.create<InstanceOp>(mem->getLoc(), memModule, memModule.getModuleName(),
mem.getNameKind(), mem.getAnnotations().getValue());

// Wire all the ports together.
for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
memInst.getResults())) {
if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
b.create<StrictConnectOp>(mem->getLoc(), dst, src);
else
b.create<StrictConnectOp>(mem->getLoc(), src, dst);
}
// If they haven't been created yet, generate modules for the external memory
// module and its wrapper module.
auto [wrapper, memModule, memInst] =
getOrCreateMemModule(mem, summary, ports, shouldDedup);

// Create an instance of the wrapper memory module, which will replace the
// original mem op.
Expand All @@ -272,11 +290,14 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
SmallVector<Annotation> newMemModAnnos;
OpBuilder nlaBuilder(context);

AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
// We're only looking for non-local annotations.
// Copy all NLAs from the original `MemOp` to the new external memory
// instance `memInst`.
AnnotationSet annos(mem);
for (auto anno : annos) {
auto nlaSym = anno.getMember<FlatSymbolRefAttr>(nonlocalAttr);
if (!nlaSym)
return false;
continue;

// If we have already seen this NLA, don't re-process it.
auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
StringAttr newNLAName;
Expand Down Expand Up @@ -306,8 +327,7 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
anno.setMember("circt.nonlocal", FlatSymbolRefAttr::get(newNLAName));
nlaUpdated = true;
newMemModAnnos.push_back(anno);
return true;
});
}
if (nlaUpdated) {
memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym));
AnnotationSet newAnnos(memInst);
Expand Down
18 changes: 5 additions & 13 deletions test/Dialect/FIRRTL/lower-memory.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,12 @@ firrtl.module @Dedup() {
%mem0_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>
%mem1_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>
// CHECK: firrtl.instance mem0 @mem0(
// CHECK: firrtl.instance mem1 @mem1(
// CHECK: firrtl.instance mem1 @mem0(
}
// CHECK: firrtl.module private @mem0
// CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext

// CHECK: firrtl.memmodule private @mem0_ext

// CHECK: firrtl.module private @mem1
// CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext
}

// Test that memories in the testharness are not deduped with other memories in
Expand Down Expand Up @@ -240,9 +237,9 @@ firrtl.module @Annotations() attributes {annotations = [{class = "sifive.enterpr
// CHECK-LABEL: firrtl.circuit "NonLocalAnnotation"
firrtl.circuit "NonLocalAnnotation" {

// CHECK: hw.hierpath private @[[nla_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0]
// CHECK: hw.hierpath private @[[nla0_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0]
hw.hierpath private @nla0 [@NonLocalAnnotation::@dut, @DUT::@mem0]
// CHECK: hw.hierpath private @[[nla_1:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem1]
// CHECK: hw.hierpath private @[[nla1_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem0]
hw.hierpath private @nla1 [@NonLocalAnnotation::@dut, @DUT]

// CHECK: firrtl.module @NonLocalAnnotation()
Expand All @@ -256,7 +253,7 @@ firrtl.module @DUT() {
%mem0_write = firrtl.mem sym @mem0 Undefined {annotations = [{circt.nonlocal = @nla0, class = "test0"}], depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>

// This memory does not have a symbol already attached.
// CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem1
// CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem0
%mem1_write = firrtl.mem Undefined {annotations = [{circt.nonlocal = @nla1, class = "test1"}], depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>

// LowerMemory should ignore MemOps that are not seqmems. The following memory is a combmem with readLatency=0.
Expand All @@ -266,12 +263,7 @@ firrtl.module @DUT() {

// CHECK: firrtl.module private @mem0
// CHECK: firrtl.instance mem0_ext sym @mem0_ext
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_0]], class = "test0"}]}
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla0_0]], class = "test0"}, {circt.nonlocal = @[[nla1_0]], class = "test1"}]}
// CHECK-SAME: @mem0_ext(
// CHECK: }

// CHECK: firrtl.module private @mem1
// CHECK: firrtl.instance mem0_ext sym @mem0_ext
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_1]], class = "test1"}]}
// CHECK-SAME: @mem0_ext(
}

0 comments on commit 2e23cda

Please sign in to comment.