From 7bc4de53b35dc890d45a83ca18b84b69e94efd89 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 12 Mar 2024 20:10:00 -0700 Subject: [PATCH] [FIRRTL] Support alternative base paths in LowerClasses. We previously had a constraint that paths are targeting entities in the same owning module as the path. This allowed us to assume that a single base path can be created for each owning module, and used in all paths. However, we have transforms that extract entities targeted by paths outside the paths' owning module. To support this, we can instead plumb through base paths created higher up in the instance graph down to the paths that need them. When a path is detected outside the owning module, we find the module in which the entity is instantiated and pass through its base path to the paths that reference the entity. Most of the work here is just plumbing, and the PathInfoTable has been extended with new data structures and functions to support this. --- .../FIRRTL/Transforms/LowerClasses.cpp | 252 +++++++++++++++--- test/Dialect/FIRRTL/lower-classes.mlir | 34 +++ 2 files changed, 255 insertions(+), 31 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index f9d56c6f0d1f..e0fb114eb3ff 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/Threading.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" @@ -39,7 +40,9 @@ namespace { /// to the targeted operation. struct PathInfo { PathInfo() = default; - PathInfo(Operation *op, FlatSymbolRefAttr symRef) : op(op), symRef(symRef) { + PathInfo(Operation *op, FlatSymbolRefAttr symRef, + StringAttr altBasePathModule) + : op(op), symRef(symRef), altBasePathModule(altBasePathModule) { assert(op && "op must not be null"); assert(symRef && "symRef must not be null"); } @@ -51,12 +54,82 @@ struct PathInfo { /// A reference to the hierarchical path targeting the op. FlatSymbolRefAttr symRef = nullptr; + + /// The module name of the root module from which we take an alternative base + /// path. + StringAttr altBasePathModule = nullptr; }; /// Maps a FIRRTL path id to the lowered PathInfo. struct PathInfoTable { + // Add an alternative base path root module. The default base path from this + // module with be passed through to where it is needed. + void addAltBasePathRoot(StringAttr rootModuleName) { + altBasePathRoots.insert(rootModuleName); + } + + // Add a passthrough module for a given root module. The default base path + // from the root module will be passed through the passthrough module. + void addAltBasePathPassthrough(StringAttr passthroughModuleName, + StringAttr rootModuleName) { + auto &rootSequence = altBasePathsPassthroughs[passthroughModuleName]; + rootSequence.push_back(rootModuleName); + } + + // Get an iterator range over the alternative base path root module names. + llvm::iterator_range::iterator> + getAltBasePathRoots() const { + return llvm::make_range(altBasePathRoots.begin(), altBasePathRoots.end()); + } + + // Get the number of alternative base paths passing through the given + // passthrough module. + size_t getNumAltBasePaths(StringAttr passthroughModuleName) const { + return altBasePathsPassthroughs.lookup(passthroughModuleName).size(); + } + + // Get the root modules that are passing an alternative base path through the + // given passthrough module. + llvm::iterator_range::iterator> + getRootsForPassthrough(StringAttr passthroughModuleName) const { + auto rootSequence = altBasePathsPassthroughs.lookup(passthroughModuleName); + return llvm::make_range(rootSequence.begin(), rootSequence.end()); + } + + // Collect alternative base paths passing through `instance`, by looking up + // its associated `moduleNameAttr`. The results are colelcted in `result`. + void collectAltBasePaths(Operation *instance, StringAttr moduleNameAttr, + SmallVectorImpl &result) const { + auto altBasePaths = altBasePathsPassthroughs.lookup(moduleNameAttr); + auto parent = instance->getParentOfType(); + + // Handle each alternative base path for instances of this module-like. + for (auto [i, altBasePath] : llvm::enumerate(altBasePaths)) { + if (parent.getName().starts_with(altBasePath)) { + // If we are passing down from the root, take the root base path. + result.push_back(instance->getBlock()->getArgument(0)); + } else { + // Otherwise, pass through the appropriate base path from above. + // + 1 to skip default base path + auto basePath = instance->getBlock()->getArgument(1 + i); + assert(isa(basePath) && + "expected a passthrough base path"); + result.push_back(basePath); + } + } + } + // The table mapping DistinctAttrs to PathInfo structs. DenseMap table; + +private: + // Module name attributes indicating modules whose base path input should + // be used as alternate base paths. + SmallPtrSet altBasePathRoots; + + // Module name attributes mapping from modules who pass through alternative + // base paths from their parents to a sequence of the parents' module names. + DenseMap> altBasePathsPassthroughs; }; /// The suffix to append to lowered module names. @@ -94,16 +167,20 @@ struct LowerClassesPass : public LowerClassesBase { bool shouldCreateClass(FModuleLike moduleLike); // Create an OM Class op from a FIRRTL Class op. - om::ClassLike createClass(FModuleLike moduleLike); + om::ClassLike createClass(FModuleLike moduleLike, + const PathInfoTable &pathInfoTable); // Lower the FIRRTL Class to OM Class. - void lowerClassLike(FModuleLike moduleLike, om::ClassLike classLike); - void lowerClass(om::ClassOp classOp, FModuleLike moduleLike); + void lowerClassLike(FModuleLike moduleLike, om::ClassLike classLike, + const PathInfoTable &pathInfoTable); + void lowerClass(om::ClassOp classOp, FModuleLike moduleLike, + const PathInfoTable &pathInfoTable); void lowerClassExtern(ClassExternOp classExternOp, FModuleLike moduleLike); // Update Object instantiations in a FIRRTL Module or OM Class. LogicalResult updateInstances(Operation *op, InstanceGraph &instanceGraph, - const LoweringState &state); + const LoweringState &state, + const PathInfoTable &pathInfoTable); // Convert to OM ops and types in Classes or Modules. LogicalResult dialectConversion( @@ -269,18 +346,19 @@ LogicalResult LowerClassesPass::processPaths( // Copy the leading part of the hierarchical path from the owning module // to the start of the annotation's NLA. + bool needsAltBasePath = false; auto *node = instanceGraph.lookup(moduleName); while (true) { // If the path is rooted at the owning module, we're done. if (node->getModule() == owningModule) break; // If there are no more parents, then the path op lives in a different - // hierarchy than the HW object it references, which is an error. + // hierarchy than the HW object it references, which needs to handled + // specially. Flag this, so we know to create an alternative base path + // below. if (node->noUses()) { - op->emitError() << "unable to resolve path relative to owning module " - << owningModule.getModuleNameAttr(); - error = true; - return false; + needsAltBasePath = true; + break; } // If there is more than one instance of this module, then the path // operation is ambiguous, which is an error. @@ -300,8 +378,16 @@ LogicalResult LowerClassesPass::processPaths( std::reverse(path.begin(), path.end()); auto pathAttr = ArrayAttr::get(context, path); + // If we need an alternative base path, save the top module from the path. + // We will plumb in the basepath from this module. + StringAttr altBasePathModule; + if (needsAltBasePath) { + altBasePathModule = cast(path.front()).getModule(); + pathInfoTable.addAltBasePathRoot(altBasePathModule); + } + // Record the path operation associated with the path op. - pathInfo = {op, cache.getRefFor(pathAttr)}; + pathInfo = {op, cache.getRefFor(pathAttr), altBasePathModule}; // Remove this annotation from the operation. return true; @@ -330,6 +416,44 @@ LogicalResult LowerClassesPass::processPaths( if (result.wasInterrupted()) return failure(); } + + // For each module that will be passing through a base path, compute its + // descendants that need this base path passed through. + for (auto rootModule : pathInfoTable.getAltBasePathRoots()) { + InstanceGraphNode *node = instanceGraph.lookup(rootModule); + + // Do a depth first traversal of the instance graph from rootModule, looking + // for descendants that need to be passed through. + auto start = llvm::df_begin(node); + auto end = llvm::df_end(node); + auto it = start; + while (it != end) { + // Nothing to do for the root module. + if (it == start) { + ++it; + continue; + } + + // If we aren't creating a class for this child, skip this hierarchy. + if (!shouldCreateClass(it->getModule())) { + it = it.skipChildren(); + continue; + } + + // If we are at a leaf, nothing to do. + if (std::distance(it->begin(), it->end()) == 0) { + ++it; + continue; + } + + // Track state for this passthrough. + StringAttr passthroughModule = it->getModule().getModuleNameAttr(); + pathInfoTable.addAltBasePathPassthrough(passthroughModule, rootModule); + + ++it; + } + } + return success(); } @@ -361,7 +485,7 @@ void LowerClassesPass::runOnOperation() { DenseMap classTypeTable; for (auto moduleLike : circuit.getOps()) { if (shouldCreateClass(moduleLike)) { - auto omClass = createClass(moduleLike); + auto omClass = createClass(moduleLike, pathInfoTable); auto &classLoweringState = loweringState.classLoweringStateTable[omClass]; classLoweringState.moduleLike = moduleLike; @@ -394,9 +518,10 @@ void LowerClassesPass::runOnOperation() { // Move ops from FIRRTL Class to OM Class in parallel. mlir::parallelForEach(ctx, loweringState.classLoweringStateTable, - [this](auto &entry) { + [this, &pathInfoTable](auto &entry) { const auto &[classLike, state] = entry; - lowerClassLike(state.moduleLike, classLike); + lowerClassLike(state.moduleLike, classLike, + pathInfoTable); }); // Completely erase Class module-likes, and remove from the InstanceGraph. @@ -419,7 +544,8 @@ void LowerClassesPass::runOnOperation() { // Update Object creation ops in Classes or Modules in parallel. if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { - return updateInstances(op, instanceGraph, loweringState); + return updateInstances(op, instanceGraph, loweringState, + pathInfoTable); }))) return signalPassFailure(); @@ -483,11 +609,21 @@ bool LowerClassesPass::shouldCreateClass(FModuleLike moduleLike) { } // Create an OM Class op from a FIRRTL Class op or Module op with properties. -om::ClassLike LowerClassesPass::createClass(FModuleLike moduleLike) { +om::ClassLike +LowerClassesPass::createClass(FModuleLike moduleLike, + const PathInfoTable &pathInfoTable) { // Collect the parameter names from input properties. SmallVector formalParamNames; // Every class gets a base path as its first parameter. formalParamNames.emplace_back("basepath"); + + // If this class is passing through base paths from above, add those. + size_t nAltBasePaths = + pathInfoTable.getNumAltBasePaths(moduleLike.getModuleNameAttr()); + for (size_t i = 0; i < nAltBasePaths; ++i) + formalParamNames.push_back(StringAttr::get( + moduleLike->getContext(), "alt_basepath_" + llvm::Twine(i))); + for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) if (port.isInput() && isa(port.type)) formalParamNames.push_back(port.name); @@ -520,10 +656,11 @@ om::ClassLike LowerClassesPass::createClass(FModuleLike moduleLike) { } void LowerClassesPass::lowerClassLike(FModuleLike moduleLike, - om::ClassLike classLike) { + om::ClassLike classLike, + const PathInfoTable &pathInfoTable) { if (auto classOp = dyn_cast(classLike.getOperation())) { - return lowerClass(classOp, moduleLike); + return lowerClass(classOp, moduleLike, pathInfoTable); } if (auto classExternOp = dyn_cast(classLike.getOperation())) { @@ -532,7 +669,8 @@ void LowerClassesPass::lowerClassLike(FModuleLike moduleLike, llvm_unreachable("unhandled class-like op"); } -void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike) { +void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike, + const PathInfoTable &pathInfoTable) { // Map from Values in the FIRRTL Class to Values in the OM Class. IRMapping mapping; @@ -556,8 +694,16 @@ void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike) { // updating the mapping to map from the input property to the block argument. Block *classBody = &classOp->getRegion(0).emplaceBlock(); // Every class created from a module gets a base path as its first parameter. - classBody->addArgument(BasePathType::get(&getContext()), - UnknownLoc::get(&getContext())); + auto basePathType = BasePathType::get(&getContext()); + auto unknownLoc = UnknownLoc::get(&getContext()); + classBody->addArgument(basePathType, unknownLoc); + + // If this class is passing through base paths from above, add those. + size_t nAltBasePaths = + pathInfoTable.getNumAltBasePaths(moduleLike.getModuleNameAttr()); + for (size_t i = 0; i < nAltBasePaths; ++i) + classBody->addArgument(basePathType, unknownLoc); + for (auto inputProperty : inputProperties) { BlockArgument parameterValue = classBody->addArgument(inputProperty.type, inputProperty.loc); @@ -661,6 +807,7 @@ void LowerClassesPass::lowerClassExtern(ClassExternOp classExternOp, // converted to OM Object instances. static LogicalResult updateObjectInClass(firrtl::ObjectOp firrtlObject, + const PathInfoTable &pathInfoTable, SmallVectorImpl &opsToErase) { // The 0'th argument is the base path. auto basePath = firrtlObject->getBlock()->getArgument(0); @@ -671,7 +818,13 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, llvm::SmallVector argIndexTable; argIndexTable.resize(numElements); - unsigned nextArgIndex = 1; + // Get any alternative base paths passing through this module. + SmallVector altBasePaths; + pathInfoTable.collectAltBasePaths( + firrtlObject, firrtlClassType.getNameAttr().getAttr(), altBasePaths); + + // Account for the default base path and any alternatives. + unsigned nextArgIndex = 1 + altBasePaths.size(); for (unsigned i = 0; i < numElements; ++i) { auto direction = firrtlClassType.getElement(i).direction; @@ -686,6 +839,10 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, args.resize(nextArgIndex); args[0] = basePath; + // Collect any alternative base paths passing through. + for (auto [i, altBasePath] : llvm::enumerate(altBasePaths)) + args[1 + i] = altBasePath; // + 1 to skip default base path + for (auto *user : llvm::make_early_inc_range(firrtlObject->getUsers())) { if (auto subfield = dyn_cast(user)) { auto index = subfield.getIndex(); @@ -751,6 +908,7 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, static LogicalResult updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, InstanceGraph &instanceGraph, + const PathInfoTable &pathInfoTable, SmallVectorImpl &opsToErase) { // Set the insertion point right before the instance op. @@ -767,6 +925,12 @@ updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, firrtlInstance->getLoc(), basePath, symRef); actualParameters.push_back(rebasedPath); + + // Add any alternative base paths passing through this instance. + pathInfoTable.collectAltBasePaths( + firrtlInstance, firrtlInstance.getModuleNameAttr().getAttr(), + actualParameters); + for (auto result : firrtlInstance.getResults()) { // If the port is an output, continue. if (firrtlInstance.getPortDirection(result.getResultNumber()) == @@ -886,17 +1050,18 @@ updateInstancesInModule(FModuleOp moduleOp, InstanceGraph &instanceGraph, static LogicalResult updateObjectsAndInstancesInClass( om::ClassOp classOp, InstanceGraph &instanceGraph, - const LoweringState &state, SmallVectorImpl &opsToErase) { + const LoweringState &state, const PathInfoTable &pathInfoTable, + SmallVectorImpl &opsToErase) { OpBuilder builder(classOp); auto &classState = state.classLoweringStateTable.at(classOp); auto it = classState.paths.begin(); for (auto &op : classOp->getRegion(0).getOps()) { if (auto objectOp = dyn_cast(op)) { - if (failed(updateObjectInClass(objectOp, opsToErase))) + if (failed(updateObjectInClass(objectOp, pathInfoTable, opsToErase))) return failure(); } else if (auto instanceOp = dyn_cast(op)) { if (failed(updateInstanceInClass(instanceOp, *it++, instanceGraph, - opsToErase))) + pathInfoTable, opsToErase))) return failure(); } } @@ -904,9 +1069,10 @@ static LogicalResult updateObjectsAndInstancesInClass( } // Update Object or Module instantiations in a FIRRTL Module or OM Class. -LogicalResult LowerClassesPass::updateInstances(Operation *op, - InstanceGraph &instanceGraph, - const LoweringState &state) { +LogicalResult +LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, + const LoweringState &state, + const PathInfoTable &pathInfoTable) { // Track ops to erase at the end. We can't do this eagerly, since we want to // loop over each op in the container's body, and we may end up removing some @@ -923,8 +1089,8 @@ LogicalResult LowerClassesPass::updateInstances(Operation *op, .Case([&](om::ClassOp classOp) { // Convert FIRRTL Module instance within a Class to OM // Object instance. - return updateObjectsAndInstancesInClass(classOp, instanceGraph, - state, opsToErase); + return updateObjectsAndInstancesInClass( + classOp, instanceGraph, state, pathInfoTable, opsToErase); }) .Default([](auto *op) { return success(); }); if (failed(result)) @@ -1060,7 +1226,7 @@ struct PathOpConversion : public OpConversionPattern { auto pathType = om::PathType::get(context); auto pathInfo = pathInfoTable.table.lookup(op.getTarget()); - // The 0'th argument is the base path. + // The 0'th argument is the base path by default. auto basePath = op->getBlock()->getArgument(0); // If the target was optimized away, then replace the path operation with @@ -1102,6 +1268,30 @@ struct PathOpConversion : public OpConversionPattern { break; } + // If we are using an alternative base path for this path, get it from the + // passthrough port on the enclosing class. + if (auto altBasePathModule = pathInfo.altBasePathModule) { + // Get the base paths passing through the parent. + auto parent = op->getParentOfType(); + auto originalParentName = + StringAttr::get(op->getContext(), + parent.getName().drop_back(kClassNameSuffix.size())); + auto altBasePaths = + pathInfoTable.getRootsForPassthrough(originalParentName); + assert(!altBasePaths.empty() && "expected passthrough base paths"); + + // Find the base path passthrough that was associated with this path. + for (auto [i, altBasePath] : llvm::enumerate(altBasePaths)) { + if (altBasePathModule == altBasePath) { + // + 1 to skip default base path + auto basePathArg = op->getBlock()->getArgument(1 + i); + assert(isa(basePathArg) && + "expected a passthrough base path"); + basePath = basePathArg; + } + } + } + rewriter.replaceOpWithNewOp( op, pathType, om::TargetKindAttr::get(op.getContext(), targetKind), basePath, symbol); diff --git a/test/Dialect/FIRRTL/lower-classes.mlir b/test/Dialect/FIRRTL/lower-classes.mlir index 2b270121b354..b85cdbae8f1a 100644 --- a/test/Dialect/FIRRTL/lower-classes.mlir +++ b/test/Dialect/FIRRTL/lower-classes.mlir @@ -374,3 +374,37 @@ firrtl.circuit "IntegerArithmetic" { %4 = firrtl.integer.shr %0, %1 : (!firrtl.integer, !firrtl.integer) -> !firrtl.integer } } + +// CHECK-LABEL: firrtl.circuit "AltBasePath" +firrtl.circuit "AltBasePath" { + firrtl.class private @Node(in %path: !firrtl.path) { + } + + // CHECK: om.class @OMIR(%basepath: !om.basepath, %alt_basepath_0: !om.basepath) + firrtl.class private @OMIR() { + %node = firrtl.object @Node(in path: !firrtl.path) + %0 = firrtl.object.subfield %node[path] : !firrtl.class<@Node(in path: !firrtl.path)> + + // CHECK: om.path_create member_instance %alt_basepath_0 + %1 = firrtl.path member_reference distinct[0]<> + firrtl.propassign %0, %1 : !firrtl.path + } + + // CHECK: om.class @DUT_Class(%basepath: !om.basepath, %alt_basepath_0: !om.basepath) + firrtl.module @DUT(out %omirOut: !firrtl.class<@OMIR()>) attributes {convention = #firrtl} { + // CHECK: om.object @OMIR(%basepath, %alt_basepath_0) + %omir = firrtl.object @OMIR() + firrtl.propassign %omirOut, %omir : !firrtl.class<@OMIR()> + } + + // CHECK: om.class @AltBasePath_Class(%basepath: !om.basepath) + firrtl.module @AltBasePath() attributes {convention = #firrtl} { + // CHECK: om.object @DUT_Class(%0, %basepath) + %dut_omirOut = firrtl.instance dut interesting_name @DUT(out omirOut: !firrtl.class<@OMIR()>) + firrtl.instance foo interesting_name {annotations = [{class = "circt.tracker", id = distinct[0]<>}]} @Foo() + } + + firrtl.module private @Foo() attributes {annotations = [{class = "circt.tracker", id = distinct[1]<>}]} { + firrtl.skip + } +}