Skip to content

Commit

Permalink
[InferReadWrite] Add heuristic to infer unmasked memory (#6790)
Browse files Browse the repository at this point in the history
This PR updates the heuristic to infer an unmasked memory.
If all the bits of the mask signal are driven by the same value, then it can
 be replaced with an unmasked memory. 
(Example:  `mem_RW0_wmask = {6{baseWrEn_F1}}`)
This is an attempt to fix a use case, in which firtool introduces masked memory
 for an aggregate data type when the user expected an unmasked one.
  • Loading branch information
prithayan committed Mar 14, 2024
1 parent 315892f commit a25a583
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
109 changes: 105 additions & 4 deletions lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,103 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
return {};
}

void handleCatPrimOp(CatPrimOp defOp, SmallVectorImpl<Value> &bits) {

long lastSize = 0;
// Cat the bits of both the operands.
for (auto operand : defOp->getOperands()) {
SmallVectorImpl<Value> &opBits = valueBitsSrc[operand];
size_t s =
getBitWidth(type_cast<FIRRTLBaseType>(operand.getType())).value();
assert(opBits.size() == s);
for (long i = lastSize, e = lastSize + s; i != e; ++i)
bits[i] = opBits[i - lastSize];
lastSize = s;
}
}

void handleBitsPrimOp(BitsPrimOp bitsPrim, SmallVectorImpl<Value> &bits) {

SmallVectorImpl<Value> &opBits = valueBitsSrc[bitsPrim.getInput()];
for (size_t srcIndex = bitsPrim.getLo(), e = bitsPrim.getHi(), i = 0;
srcIndex <= e; ++srcIndex, ++i)
bits[i] = opBits[srcIndex];
}

// Try to extract the value assigned to each bit of `val`. This is a heuristic
// to determine if each bit of the `val` is assigned the same value.
// Common pattern that this heuristic detects,
// mask = {{w1,w1},{w2,w2}}}
// w1 = w[0]
// w2 = w[0]
bool areBitsDrivenBySameSource(Value val) {
SmallVector<Value> stack;
stack.push_back(val);

while (!stack.empty()) {
auto val = stack.back();
if (valueBitsSrc.contains(val)) {
stack.pop_back();
continue;
}

auto size = getBitWidth(type_cast<FIRRTLBaseType>(val.getType()));
// Cannot analyze aggregate types.
if (!size.has_value())
return false;

auto bitsSize = size.value();
if (auto *defOp = val.getDefiningOp()) {
if (isa<CatPrimOp>(defOp)) {
bool operandsDone = true;
// If the value is a cat of other values, compute the bits of the
// operands.
for (auto operand : defOp->getOperands()) {
if (valueBitsSrc.contains(operand))
continue;
stack.push_back(operand);
operandsDone = false;
}
if (!operandsDone)
continue;

valueBitsSrc[val].resize_for_overwrite(bitsSize);
handleCatPrimOp(cast<CatPrimOp>(defOp), valueBitsSrc[val]);
} else if (auto bitsPrim = dyn_cast<BitsPrimOp>(defOp)) {
auto input = bitsPrim.getInput();
if (!valueBitsSrc.contains(input)) {
stack.push_back(input);
continue;
}
valueBitsSrc[val].resize_for_overwrite(bitsSize);
handleBitsPrimOp(bitsPrim, valueBitsSrc[val]);
} else if (auto constOp = dyn_cast<ConstantOp>(defOp)) {
auto constVal = constOp.getValue();
valueBitsSrc[val].resize_for_overwrite(bitsSize);
if (constVal.isAllOnes() || constVal.isZero()) {
for (auto &b : valueBitsSrc[val])
b = constOp;
} else
return false;
} else if (auto wireOp = dyn_cast<WireOp>(defOp)) {
if (bitsSize != 1)
return false;
valueBitsSrc[val].resize_for_overwrite(bitsSize);
if (auto src = getConnectSrc(wireOp.getResult())) {
valueBitsSrc[val][0] = src;
} else
valueBitsSrc[val][0] = wireOp.getResult();
} else
return false;
} else
return false;
stack.pop_back();
}
if (!valueBitsSrc.contains(val))
return false;
return llvm::all_equal(valueBitsSrc[val]);
}

// Remove redundant dependence of wmode on the enable signal. wmode can assume
// the enable signal be true.
void simplifyWmode(MemOp &memOp) {
Expand Down Expand Up @@ -404,11 +501,11 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
if (sf.getResult().getType().getBitWidthOrSentinel() == 1)
continue;
// Check what is the mask field directly connected to.
// If, a constant 1, then we can replace with unMasked memory.
// If we can infer that all the bits of the mask are always assigned
// the same value, then the memory is unmasked.
if (auto maskVal = getConnectSrc(sf))
if (auto constVal = dyn_cast<ConstantOp>(maskVal.getDefiningOp()))
if (constVal.getValue().isAllOnes())
isMasked = false;
if (areBitsDrivenBySameSource(maskVal))
isMasked = false;
}
}
}
Expand Down Expand Up @@ -467,6 +564,10 @@ struct InferReadWritePass : public InferReadWriteBase<InferReadWritePass> {
memOp = newMem;
}
}

// Record of what are the source values that drive each bit of a value. Used
// to check if each bit of a value is being driven by the same source.
llvm::DenseMap<Value, SmallVector<Value>> valueBitsSrc;
};
} // end anonymous namespace

Expand Down
20 changes: 20 additions & 0 deletions test/Dialect/FIRRTL/inferRW.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -302,5 +302,25 @@ firrtl.circuit "TLRAM" {
// CHECK: %[[v7:.+]] = firrtl.mux(%[[c1_ui1]], %rwPort_isWrite, %c0_ui1)
firrtl.strictconnect %mem_rwPort_readData_rw_wmode, %18 : !firrtl.uint<1>
}

// CHECK: firrtl.module @InferUnmasked
firrtl.module @InferUnmasked(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) attributes {convention = #firrtl<convention scalarized>} {
%readwritePortA_isWrite_2 = firrtl.wire {name = "readwritePortA_isWrite"} : !firrtl.uint<1>
%syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
// CHECK: %syncreadmem_singleport_readwritePortA_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "syncreadmem_singleport", portNames = ["readwritePortA_readData_rw"], readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<1>>
%syncreadmem_singleport_readwritePortA_readData_rw_wmask_x = firrtl.wire : !firrtl.uint<1>
%syncreadmem_singleport_readwritePortA_readData_rw_wmask_y = firrtl.wire : !firrtl.uint<1>
%9 = firrtl.subfield %syncreadmem_singleport_readwritePortA_readData_rw[wmask] : !firrtl.bundle<addr: uint<6>, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>>
%10 = firrtl.cat %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
%11 = firrtl.bits %10 0 to 0 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%12 = firrtl.cat %11, %11 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2>
%13 = firrtl.cat %11, %12 : (!firrtl.uint<1>, !firrtl.uint<2>) -> !firrtl.uint<3>
%14 = firrtl.bits %10 1 to 1 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%15 = firrtl.cat %14, %13 : (!firrtl.uint<1>, !firrtl.uint<3>) -> !firrtl.uint<4>
%16 = firrtl.cat %14, %15 : (!firrtl.uint<1>, !firrtl.uint<4>) -> !firrtl.uint<5>
firrtl.strictconnect %9, %16 : !firrtl.uint<5>
firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_x, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
firrtl.connect %syncreadmem_singleport_readwritePortA_readData_rw_wmask_y, %readwritePortA_isWrite_2 : !firrtl.uint<1>, !firrtl.uint<1>
}
}

0 comments on commit a25a583

Please sign in to comment.