Skip to content

Commit

Permalink
Revert "[FIRRTL] Dedup memory wrapper modules in LowerMemory (llvm#6719
Browse files Browse the repository at this point in the history
…)"

This reverts commit 2e23cda.

This change caused memories that were previously deduping to no
longer dedupe, so we're reverting this while we investigate. An
issue will be created to include a small reproducer and track
re-landing this change.
  • Loading branch information
mikeurbach committed Mar 14, 2024
1 parent a25a583 commit bea8510
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 77 deletions.
124 changes: 52 additions & 72 deletions lib/Dialect/FIRRTL/Transforms/LowerMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ FirMemory getSummary(MemOp op) {
}

namespace {
struct MemOps {
FModuleOp wrapperModule;
FMemModuleOp memModule;
InstanceOp memInst;
};

struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {

/// Get the cached namespace for a module.
Expand All @@ -107,11 +101,11 @@ struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {
}

SmallVector<PortInfo> getMemoryModulePorts(const FirMemory &mem);
MemOps emitMemoryModule(MemOp op, const FirMemory &mem,
const SmallVectorImpl<PortInfo> &ports);
MemOps getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup);
FMemModuleOp emitMemoryModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports);
FMemModuleOp getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup);
FModuleOp createWrapperModule(MemOp op, const FirMemory &summary,
bool shouldDedup);
InstanceOp emitMemoryInstance(MemOp op, FModuleOp module,
Expand All @@ -127,7 +121,7 @@ struct LowerMemoryPass : public LowerMemoryBase<LowerMemoryPass> {

/// The set of all memories seen so far. This is used to "deduplicate"
/// memories by emitting modules one module for equivalent memories.
std::map<FirMemory, MemOps> memories;
std::map<FirMemory, FMemModuleOp> memories;
};
} // end anonymous namespace

Expand Down Expand Up @@ -184,61 +178,25 @@ LowerMemoryPass::getMemoryModulePorts(const FirMemory &mem) {
return ports;
}

MemOps
FMemModuleOp
LowerMemoryPass::emitMemoryModule(MemOp op, const FirMemory &mem,
const SmallVectorImpl<PortInfo> &ports) {
auto *context = &getContext();
// Get a non-colliding name for the memory module, and update the summary.
auto newName = circuitNamespace.newName(mem.modName.getValue(), "ext");
auto moduleName = StringAttr::get(&getContext(), newName);

// Get non-colliding names for the memory module and its wrapper.
auto newMemName = circuitNamespace.newName(mem.modName.getValue(), "ext");
auto memModuleName = StringAttr::get(context, newMemName);

auto newWrapperName = circuitNamespace.newName(op.getName());
auto wrapperName = StringAttr::get(context, newWrapperName);

// Insert the memory module and its wrapper at the bottom of the circuit.
// Insert the memory module at the bottom of the circuit.
auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());

// Create the wrapper module.
auto wrapper = b.create<FModuleOp>(
op->getLoc(), wrapperName,
ConventionAttr::get(context, Convention::Internal), ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Create the external memory module.
++numCreatedMemModules;
auto memModule = b.create<FMemModuleOp>(
mem.loc, memModuleName, ports, mem.numReadPorts, mem.numWritePorts,
auto moduleOp = b.create<FMemModuleOp>(
mem.loc, moduleName, ports, mem.numReadPorts, mem.numWritePorts,
mem.numReadWritePorts, mem.dataWidth, mem.maskBits, mem.readLatency,
mem.writeLatency, mem.depth);
SymbolTable::setSymbolVisibility(memModule, SymbolTable::Visibility::Private);

// Create an instance of the external memory module inside the wrapper
// module.
b.setInsertionPointToStart(wrapper.getBodyBlock());
auto memInst =
b.create<InstanceOp>(op->getLoc(), memModule, memModule.getModuleName(),
op.getNameKind(), op.getAnnotations().getValue());

// Wire all the ports together.
for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
memInst.getResults())) {
if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
b.create<StrictConnectOp>(op->getLoc(), dst, src);
else
b.create<StrictConnectOp>(op->getLoc(), src, dst);
}

// Remove all NLAs from the instance, we'll fix them up later.
auto nonlocalAttr = StringAttr::get(context, "circt.nonlocal");
AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
return anno.getMember<FlatSymbolRefAttr>(nonlocalAttr) != nullptr;
});

return {wrapper, memModule, memInst};
SymbolTable::setSymbolVisibility(moduleOp, SymbolTable::Visibility::Private);
return moduleOp;
}

MemOps
FMemModuleOp
LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,
const SmallVectorImpl<PortInfo> &ports,
bool shouldDedup) {
Expand All @@ -252,25 +210,49 @@ LowerMemoryPass::getOrCreateMemModule(MemOp op, const FirMemory &summary,

// Create a new module for this memory. This can update the name recorded in
// the memory's summary.
auto modules = emitMemoryModule(op, summary, ports);
auto module = emitMemoryModule(op, summary, ports);

// Record the memory module. We don't want to use this module for other
// memories, then we don't add it to the table.
if (shouldDedup)
memories[summary] = modules;
memories[summary] = module;

return modules;
return module;
}

void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
bool shouldDedup) {
auto *context = &getContext();
auto ports = getMemoryModulePorts(summary);

// If they haven't been created yet, generate modules for the external memory
// module and its wrapper module.
auto [wrapper, memModule, memInst] =
getOrCreateMemModule(mem, summary, ports, shouldDedup);
// Get a non-colliding name for the memory module, and update the summary.
auto newName = circuitNamespace.newName(mem.getName());
auto wrapperName = StringAttr::get(&getContext(), newName);

// Create the wrapper module, inserting it into the bottom of the circuit.
auto b = OpBuilder::atBlockEnd(getOperation().getBodyBlock());
auto wrapper = b.create<FModuleOp>(
mem->getLoc(), wrapperName,
ConventionAttr::get(context, Convention::Internal), ports);
SymbolTable::setSymbolVisibility(wrapper, SymbolTable::Visibility::Private);

// Create an instance of the external memory module. The instance has the
// same name as the target module.
auto memModule = getOrCreateMemModule(mem, summary, ports, shouldDedup);
b.setInsertionPointToStart(wrapper.getBodyBlock());

auto memInst =
b.create<InstanceOp>(mem->getLoc(), memModule, memModule.getModuleName(),
mem.getNameKind(), mem.getAnnotations().getValue());

// Wire all the ports together.
for (auto [dst, src] : llvm::zip(wrapper.getBodyBlock()->getArguments(),
memInst.getResults())) {
if (wrapper.getPortDirection(dst.getArgNumber()) == Direction::Out)
b.create<StrictConnectOp>(mem->getLoc(), dst, src);
else
b.create<StrictConnectOp>(mem->getLoc(), src, dst);
}

// Create an instance of the wrapper memory module, which will replace the
// original mem op.
Expand All @@ -290,14 +272,11 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
SmallVector<Annotation> newMemModAnnos;
OpBuilder nlaBuilder(context);

// Copy all NLAs from the original `MemOp` to the new external memory
// instance `memInst`.
AnnotationSet annos(mem);
for (auto anno : annos) {
AnnotationSet::removeAnnotations(memInst, [&](Annotation anno) -> bool {
// We're only looking for non-local annotations.
auto nlaSym = anno.getMember<FlatSymbolRefAttr>(nonlocalAttr);
if (!nlaSym)
continue;

return false;
// If we have already seen this NLA, don't re-process it.
auto newNLAIter = processedNLAs.find(nlaSym.getAttr());
StringAttr newNLAName;
Expand Down Expand Up @@ -327,7 +306,8 @@ void LowerMemoryPass::lowerMemory(MemOp mem, const FirMemory &summary,
anno.setMember("circt.nonlocal", FlatSymbolRefAttr::get(newNLAName));
nlaUpdated = true;
newMemModAnnos.push_back(anno);
}
return true;
});
if (nlaUpdated) {
memInst.setInnerSymAttr(hw::InnerSymAttr::get(leafSym));
AnnotationSet newAnnos(memInst);
Expand Down
18 changes: 13 additions & 5 deletions test/Dialect/FIRRTL/lower-memory.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,15 @@ firrtl.module @Dedup() {
%mem0_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>
%mem1_write = firrtl.mem Undefined {depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>
// CHECK: firrtl.instance mem0 @mem0(
// CHECK: firrtl.instance mem1 @mem0(
// CHECK: firrtl.instance mem1 @mem1(
}
// CHECK: firrtl.module private @mem0
// CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext

// CHECK: firrtl.memmodule private @mem0_ext

// CHECK: firrtl.module private @mem1
// CHECK-NEXT: firrtl.instance mem0_ext @mem0_ext
}

// Test that memories in the testharness are not deduped with other memories in
Expand Down Expand Up @@ -237,9 +240,9 @@ firrtl.module @Annotations() attributes {annotations = [{class = "sifive.enterpr
// CHECK-LABEL: firrtl.circuit "NonLocalAnnotation"
firrtl.circuit "NonLocalAnnotation" {

// CHECK: hw.hierpath private @[[nla0_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0]
// CHECK: hw.hierpath private @[[nla_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM0:.+]], @mem0]
hw.hierpath private @nla0 [@NonLocalAnnotation::@dut, @DUT::@mem0]
// CHECK: hw.hierpath private @[[nla1_0:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem0]
// CHECK: hw.hierpath private @[[nla_1:.+]] [@NonLocalAnnotation::@dut, @DUT::@[[MEM1:.+]], @mem1]
hw.hierpath private @nla1 [@NonLocalAnnotation::@dut, @DUT]

// CHECK: firrtl.module @NonLocalAnnotation()
Expand All @@ -253,7 +256,7 @@ firrtl.module @DUT() {
%mem0_write = firrtl.mem sym @mem0 Undefined {annotations = [{circt.nonlocal = @nla0, class = "test0"}], depth = 12 : i64, name = "mem0", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>

// This memory does not have a symbol already attached.
// CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem0
// CHECK: firrtl.instance mem1 sym @[[MEM1]] @mem1
%mem1_write = firrtl.mem Undefined {annotations = [{circt.nonlocal = @nla1, class = "test1"}], depth = 12 : i64, name = "mem1", portNames = ["write"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<4>, en: uint<1>, clk: clock, data: uint<42>, mask: uint<1>>

// LowerMemory should ignore MemOps that are not seqmems. The following memory is a combmem with readLatency=0.
Expand All @@ -263,7 +266,12 @@ firrtl.module @DUT() {

// CHECK: firrtl.module private @mem0
// CHECK: firrtl.instance mem0_ext sym @mem0_ext
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla0_0]], class = "test0"}, {circt.nonlocal = @[[nla1_0]], class = "test1"}]}
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_0]], class = "test0"}]}
// CHECK-SAME: @mem0_ext(
// CHECK: }

// CHECK: firrtl.module private @mem1
// CHECK: firrtl.instance mem0_ext sym @mem0_ext
// CHECK-SAME: {annotations = [{circt.nonlocal = @[[nla_1]], class = "test1"}]}
// CHECK-SAME: @mem0_ext(
}

0 comments on commit bea8510

Please sign in to comment.