From a25a583e9ffd7a37be458808c461433eb55373e8 Mon Sep 17 00:00:00 2001
From: Prithayan Barua
Date: Thu, 14 Mar 2024 16:14:49 -0400
Subject: [PATCH] [InferReadWrite] Add heuristic to infer unmasked memory (#6790)

This PR updates the heuristic to infer an unmasked memory.
If all the bits of the mask signal are driven by the same value, then it can
be replaced with an unmasked memory.
(Example: `mem_RW0_wmask = {6{baseWrEn_F1}}`)
This is an attempt to fix a use case, in which firtool introduces masked
memory for an aggregate data type when the user expected an unmasked one.
---
 .../FIRRTL/Transforms/InferReadWrite.cpp      | 109 +++++++++++++++++-
 test/Dialect/FIRRTL/inferRW.mlir              |  20 ++++
 2 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp
index 470caddf8a5c..03eaae0da3cd 100644
--- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp
+++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp
@@ -314,6 +314,103 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
     return {};
   }
 
+  // Fill `bits` with the per-bit sources of the cat result. Index 0 is the
+  // LSB, so the last operand (the low bits) is copied first.
+  void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {
+    size_t offset = 0;
+    for (auto operand : llvm::reverse(defOp->getOperands())) {
+      SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
+      size_t s =
+          getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
+      assert(opBits.size() == s);
+      for (size_t i = 0; i != s; ++i)
+        bits[offset + i] = opBits[i];
+      offset += s;
+    }
+  }
+
+  // Fill `bits` with the sources of bits [lo, hi] of the input of `bitsPrim`.
+  void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {
+    SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
+    for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
+         srcIndex <= e; ++srcIndex, ++i)
+      bits[i] = opBits[srcIndex];
+  }
+
+  // Heuristic: record the source driving each bit of `val` (looking through
+  // cat, bits, uniform constants and 1-bit wires) and return true only if
+  // all bits share one source. Any other producer makes the analysis give up.
+  // Common pattern that this heuristic detects,
+  // mask = {{w1,w1},{w2,w2}}
+  // w1 = w[0], w2 = w[0]
+  bool areBitsDrivenBySameSource(Value val) {
+    SmallVector<Value> stack;
+    stack.push_back(val);
+
+    while (!stack.empty()) {
+      auto val = stack.back();
+      if (valueBitsSrc.contains(val)) {
+        stack.pop_back();
+        continue;
+      }
+
+      auto size = getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
+      // Cannot analyze aggregate types.
+      if (!size.has_value())
+        return false;
+
+      auto bitsSize = size.value();
+      if (auto *defOp = val.getDefiningOp()) {
+        if (isa<CatPrimOp>(defOp)) {
+          bool operandsDone = true;
+          // If the value is a cat of other values, compute the bits of the
+          // operands.
+          for (auto operand : defOp->getOperands()) {
+            if (valueBitsSrc.contains(operand))
+              continue;
+            stack.push_back(operand);
+            operandsDone = false;
+          }
+          if (!operandsDone)
+            continue;
+
+          valueBitsSrc[val].resize_for_overwrite(bitsSize);
+          handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
+        } else if (auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
+          auto input = bitsPrim.getInput();
+          if (!valueBitsSrc.contains(input)) {
+            stack.push_back(input);
+            continue;
+          }
+          valueBitsSrc[val].resize_for_overwrite(bitsSize);
+          handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
+        } else if (auto constOp = dyn_cast<ConstantOp>(defOp)) {
+          auto constVal = constOp.getValue();
+          // Bail out before caching anything for a non-uniform constant.
+          if (!constVal.isAllOnes() && !constVal.isZero())
+            return false;
+          valueBitsSrc[val].resize_for_overwrite(bitsSize);
+          for (auto &b : valueBitsSrc[val])
+            b = constOp;
+        } else if (auto wireOp = dyn_cast<WireOp>(defOp)) {
+          if (bitsSize != 1)
+            return false;
+          valueBitsSrc[val].resize_for_overwrite(bitsSize);
+          if (auto src = getConnectSrc(wireOp.getResult())) {
+            valueBitsSrc[val][0] = src;
+          } else
+            valueBitsSrc[val][0] = wireOp.getResult();
+        } else
+          return false;
+      } else
+        return false;
+      stack.pop_back();
+    }
+    if (!valueBitsSrc.contains(val))
+      return false;
+    return llvm::all_equal(valueBitsSrc[val]);
+  }
+
   // Remove redundant dependence of wmode on the enable signal. wmode can assume
   // the enable signal be true.
   void simplifyWmode(MemOp &memOp) {
@@ -404,11 +501,11 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
             if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
               continue;
             // Check what is the mask field directly connected to.
-            // If, a constant 1, then we can replace with unMasked memory.
+            // If we can infer that all the bits of the mask are always assigned
+            // the same value, then the memory is unmasked.
             if (auto maskVal = getConnectSrc(sf))
-              if (auto constVal = dyn_cast<ConstantOp>(maskVal.getDefiningOp()))
-                if (constVal.getValue().isAllOnes())
-                  isMasked = false;
+              if (areBitsDrivenBySameSource(maskVal))
+                isMasked = false;
           }
         }
     }
@@ -467,6 +564,10 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
       memOp = newMem;
     }
   }
+
+  // Record of what are the source values that drive each bit of a value. Used
+  // to check if each bit of a value is being driven by the same source.
+  llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
 };
 } // end anonymous namespace
 
diff --git a/test/Dialect/FIRRTL/inferRW.mlir b/test/Dialect/FIRRTL/inferRW.mlir
index 1c3e959f74db..61a2d479e418 100644
--- a/test/Dialect/FIRRTL/inferRW.mlir
+++ b/test/Dialect/FIRRTL/inferRW.mlir
@@ -302,5 +302,25 @@ firrtl.circuit "TLRAM" {
     // CHECK: %[[v7:.+]] = firrtl.mux(%[[c1_ui1]], %rwPort_isWrite, %c0_ui1)
     firrtl.strictconnect %mem_rwPort_readData_rw_wmode, %18 : !firrtl.uint<1>
   }
+
+  // CHECK: firrtl.module @InferUnmasked
+  firrtl.module @InferUnmasked(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) attributes {convention = #firrtl<convention scalarized>} {
+    %readwritePortA_isWrite_2 = firrtl.wire {name = "readwritePortA_isWrite"} : !firrtl.uint<1>
+    %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
+    // CHECK: %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<1>>
+    %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x = firrtl.wire : !firrtl.uint<1>
+    %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y = firrtl.wire : !firrtl.uint<1>
+    %9 = firrtl.subfield %syncreadmem_singleport_readwritePortA_readData_rw[wmask] : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
+    %10 = firrtl.cat %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
+    %11 = firrtl.bits %10 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1>
+    %12 = firrtl.cat %11, %11 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
+    %13 = firrtl.cat %11, %12 : (!firrtl.uint<1>, !firrtl.uint<2>) -> !firrtl.uint<3>
+    %14 = firrtl.bits %10 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1>
+    %15 = firrtl.cat %14, %13 : (!firrtl.uint<1>, !firrtl.uint<3>) -> !firrtl.uint<4>
+    %16 = firrtl.cat %14, %15 : (!firrtl.uint<1>, !firrtl.uint<4>) -> !firrtl.uint<5>
+    firrtl.strictconnect %9, %16 : !firrtl.uint<5>
+    firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
+    firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
+  }
 }
 