Skip to content

Commit

Permalink
[Seq][Arc] Allow seq.initial to take immutable operands. Add cast ope…
Browse files Browse the repository at this point in the history
…rations

This commit adds support for seq.initial ops to take immutable operands and
introduces new cast operations:

- Update InitialOp to accept input operands of ImmutableType
- Add FromImmutableOp and ToImmutableOp for casting between immutable and regular types
- Add mergeInitialOps helper to handle operands and topological sorting
  of initial ops when lowering (SV flow -> SeqToSV, Arc flow ->
  LowerState).
- Update lowering passes and dialects to work with new initial op
  • Loading branch information
uenoku committed Sep 26, 2024
1 parent 337dad5 commit 79c5799
Show file tree
Hide file tree
Showing 21 changed files with 373 additions and 91 deletions.
6 changes: 6 additions & 0 deletions include/circt/Dialect/Seq/SeqOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ createConstantInitialValue(OpBuilder builder, Operation *constantLike);
// initial op.
Value unwrapImmutableValue(mlir::TypedValue<seq::ImmutableType> immutableVal);

// Helper function to merge initial ops within the block into a single initial
// op. Return failure if we cannot topologically sort the initial ops.
// Return null if there is no initial op in the block. Return the initial op
// otherwise.
FailureOr<seq::InitialOp> mergeInitialOps(Block *block);

} // namespace seq
} // namespace circt

Expand Down
26 changes: 24 additions & 2 deletions include/circt/Dialect/Seq/SeqOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def InitialOp : SeqOp<"initial", [SingleBlock,
See the Seq dialect rationale for a longer description.
}];

let arguments = (ins);
let arguments = (ins Variadic<ImmutableType>: $inputs);
let results = (outs Variadic<ImmutableType>); // seq.immutable values
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
Expand All @@ -721,7 +721,7 @@ def InitialOp : SeqOp<"initial", [SingleBlock,
];

let assemblyFormat = [{
$body attr-dict `:` type(results)
`(` $inputs `)` $body attr-dict `:` functional-type($inputs, results)
}];

let extraClassDeclaration = [{
Expand All @@ -740,3 +740,25 @@ def YieldOp : SeqOp<"yield",

let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

def FromImmutableOp : SeqOp<"from_immutable", [Pure]> {
let summary = "Cast from an immutable type to a wire type";

let arguments = (ins ImmutableType:$input);
let results = (outs AnyType:$output);

let assemblyFormat = "$input attr-dict `:` functional-type(operands, results)";
}

def ToImmutableOp : SeqOp<"to_immutable", [Pure]> {
let summary = "Cast from a wire type to an immutable type.";
let description = [{
This is an unsafe cast op that converts a HW type to an immutable type. It assumes
the value is valid at initialization phase.
}];

let arguments = (ins AnyType:$input);
let results = (outs ImmutableType:$output);

let assemblyFormat = "$input attr-dict `:` functional-type(operands, results)";
}
4 changes: 2 additions & 2 deletions integration_test/Bindings/Python/dialects/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def top(module):
poweron_value = hw.ConstantOp.create(i32, 42).result
# CHECK: %[[INPUT_VAL:.+]] = hw.constant 45
reg_input = hw.ConstantOp.create(i32, 45).result
# CHECK-NEXT: %[[POWERON_VAL:.+]] = seq.initial {
# CHECK-NEXT: %[[POWERON_VAL:.+]] = seq.initial() {
# CHECK-NEXT: %[[C42:.+]] = hw.constant 42 : i32
# CHECK-NEXT: seq.yield %[[C42]] : i32
# CHECK-NEXT: } : !seq.immutable<i32>
# CHECK-NEXT: } : () -> !seq.immutable<i32>
# CHECK: %[[DATA_VAL:.+]] = seq.compreg %[[INPUT_VAL]], %clk reset %rst, %[[RESET_VAL]] initial %[[POWERON_VAL]]
reg = seq.CompRegOp(i32,
reg_input,
Expand Down
4 changes: 2 additions & 2 deletions integration_test/arcilator/JIT/reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ hw.module @counter(in %clk: i1, out o1: i8, out o2: i8) {

%r0 = seq.compreg %added1, %seq_clk initial %0#0 : i8
%r1 = seq.compreg %added2, %seq_clk initial %0#1 : i8
%0:2 = seq.initial {
%0:2 = seq.initial () {
%1 = func.call @random() : () -> i32
%2 = comb.extract %1 from 0 : (i32) -> i8
%3 = hw.constant 5 : i8
seq.yield %2, %3: i8, i8
} : !seq.immutable<i8>, !seq.immutable<i8>
} : () -> (!seq.immutable<i8>, !seq.immutable<i8>)

%one = hw.constant 1 : i8
%added1 = comb.add %r0, %one : i8
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/dialects/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self,
if power_on_value.owner is None:
assert False, "Initial value must not be port"
elif isinstance(power_on_value.owner.opview, hw.ConstantOp):
init = InitialOp([seq.ImmutableType.get(power_on_value.type)])
init = InitialOp([seq.ImmutableType.get(power_on_value.type)], [])
init.body.blocks.append()
with InsertionPoint(init.body.blocks[0]):
cloned_constant = power_on_value.owner.clone()
Expand Down
9 changes: 4 additions & 5 deletions lib/Conversion/ConvertToArcs/ConvertToArcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ static LogicalResult convertInitialValue(seq::CompRegOp reg,
if (!reg.getInitialValue())
return values.push_back({}), success();

// unrealized_conversion_cast to normal type
// Use from_immutable cast to convert the seq.immutable type to the reg's
// type.
OpBuilder builder(reg);
auto init = builder
.create<mlir::UnrealizedConversionCastOp>(
reg.getLoc(), reg.getType(), reg.getInitialValue())
.getResult(0);
auto init = builder.create<seq::FromImmutableOp>(reg.getLoc(), reg.getType(),
reg.getInitialValue());

values.push_back(init);
return success();
Expand Down
107 changes: 86 additions & 21 deletions lib/Conversion/SeqToSV/SeqToSV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "circt/Dialect/SV/SVOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/Naming.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
Expand Down Expand Up @@ -59,11 +60,10 @@ struct SeqToSVPass : public impl::LowerSeqToSVBase<SeqToSVPass> {
namespace {
struct ModuleLoweringState {
ModuleLoweringState(HWModuleOp module)
: initalOpLowering(module), module(module) {}
: immutableValueLowering(module), module(module) {}

struct InitialOpLowering {
InitialOpLowering(hw::HWModuleOp module)
: builder(module.getModuleBody()), module(module) {}
struct ImmutableValueLowering {
ImmutableValueLowering(hw::HWModuleOp module) : module(module) {}

// Lower initial ops.
LogicalResult lower();
Expand All @@ -82,9 +82,8 @@ struct ModuleLoweringState {
// defined in SV initial op.
MapVector<mlir::TypedValue<seq::ImmutableType>, Value> mapping;

OpBuilder builder;
hw::HWModuleOp module;
} initalOpLowering;
} immutableValueLowering;

struct FragmentInfo {
bool needsRegFragment = false;
Expand All @@ -94,21 +93,39 @@ struct ModuleLoweringState {
HWModuleOp module;
};

LogicalResult ModuleLoweringState::InitialOpLowering::lower() {
auto loweringFailed = module
.walk([&](seq::InitialOp initialOp) {
if (failed(lower(initialOp)))
return mlir::WalkResult::interrupt();
return mlir::WalkResult::advance();
})
.wasInterrupted();
return LogicalResult::failure(loweringFailed);
LogicalResult ModuleLoweringState::ImmutableValueLowering::lower() {
auto result = mergeInitialOps(module.getBodyBlock());
if (failed(result))
return failure();

auto initialOp = *result;
if (!initialOp)
return success();

return lower(initialOp);
}

LogicalResult
ModuleLoweringState::InitialOpLowering::lower(seq::InitialOp initialOp) {
ModuleLoweringState::ImmutableValueLowering::lower(seq::InitialOp initialOp) {
OpBuilder builder = OpBuilder::atBlockBegin(module.getBodyBlock());
if (!svInitialOp)
svInitialOp = builder.create<sv::InitialOp>(initialOp->getLoc());
// Replace immutable operands passed to initial op with already lowered
// values.
for (auto [blockArgument, operand] :
llvm::zip(initialOp.getBodyBlock()->getArguments(),
initialOp->getOpOperands())) {

auto immut = operand.get().getDefiningOp<seq::ToImmutableOp>();
if (!immut)
return initialOp.emitError()
<< "invalid operand to initial op: " << operand.get();
blockArgument.replaceAllUsesWith(immut.getInput());
operand.drop();
if (immut.use_empty())
immut.erase();
}

auto loc = initialOp.getLoc();
llvm::SmallVector<Value> results;

Expand All @@ -127,10 +144,10 @@ ModuleLoweringState::InitialOpLowering::lower(seq::InitialOp initialOp) {
}

svInitialOp.getBodyBlock()->getOperations().splice(
svInitialOp.begin(), initialOp.getBodyBlock()->getOperations());
svInitialOp.end(), initialOp.getBodyBlock()->getOperations());

assert(initialOp->use_empty());
initialOp->erase();
initialOp.erase();
yieldOp->erase();
return success();
}
Expand Down Expand Up @@ -200,7 +217,7 @@ class CompRegLower : public OpConversionPattern<OpTy> {
auto module = reg->template getParentOfType<hw::HWModuleOp>();
const auto &initial =
moduleLoweringStates.find(module.getModuleNameAttr())
->second.initalOpLowering;
->second.immutableValueLowering;

Value initialValue = initial.lookupImmutableValue(init);

Expand Down Expand Up @@ -247,6 +264,49 @@ void CompRegLower<CompRegClockEnabledOp>::createAssign(
});
}

/// Lower FromImmutable to `sv.reg` and `sv.initial`.
class FromImmutableLowering : public OpConversionPattern<FromImmutableOp> {
public:
FromImmutableLowering(
TypeConverter &typeConverter, MLIRContext *context,
const MapVector<StringAttr, ModuleLoweringState> &moduleLoweringStates)
: OpConversionPattern<FromImmutableOp>(typeConverter, context),
moduleLoweringStates(moduleLoweringStates) {}

using OpAdaptor = typename OpConversionPattern<FromImmutableOp>::OpAdaptor;

LogicalResult
matchAndRewrite(FromImmutableOp fromImmutableOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = fromImmutableOp.getLoc();

auto regTy = ConversionPattern::getTypeConverter()->convertType(
fromImmutableOp.getType());
auto svReg = rewriter.create<sv::RegOp>(loc, regTy);

auto regVal = rewriter.create<sv::ReadInOutOp>(loc, svReg);

// Lower initial values.
auto module = fromImmutableOp->template getParentOfType<hw::HWModuleOp>();
const auto &initial = moduleLoweringStates.find(module.getModuleNameAttr())
->second.immutableValueLowering;

Value initialValue =
initial.lookupImmutableValue(fromImmutableOp.getInput());

OpBuilder::InsertionGuard guard(rewriter);
auto in = initial.getSVInitial();
rewriter.setInsertionPointToEnd(in.getBodyBlock());
rewriter.create<sv::BPAssignOp>(fromImmutableOp->getLoc(), svReg,
initialValue);

rewriter.replaceOp(fromImmutableOp, regVal);
return success();
}

private:
const MapVector<StringAttr, ModuleLoweringState> &moduleLoweringStates;
};
// Lower seq.clock_gate to a fairly standard clock gate implementation.
//
class ClockGateLowering : public OpConversionPattern<ClockGateOp> {
Expand Down Expand Up @@ -537,7 +597,7 @@ void SeqToSVPass::runOnOperation() {
moduleLoweringStates.try_emplace(module.getModuleNameAttr(),
ModuleLoweringState(module));

mlir::parallelForEach(
auto result = mlir::failableParallelForEach(
&getContext(), moduleLoweringStates, [&](auto &moduleAndState) {
auto &state = moduleAndState.second;
auto module = state.module;
Expand All @@ -561,9 +621,12 @@ void SeqToSVPass::runOnOperation() {
}
needsMemRandomization = true;
}
(void)state.initalOpLowering.lower();
return state.immutableValueLowering.lower();
});

if (failed(result))
return signalPassFailure();

auto randomInitFragmentName =
FlatSymbolRefAttr::get(context, "RANDOM_INIT_FRAGMENT");
auto randomInitRegFragmentName =
Expand Down Expand Up @@ -605,6 +668,8 @@ void SeqToSVPass::runOnOperation() {
moduleLoweringStates);
patterns.add<CompRegLower<CompRegClockEnabledOp>>(
typeConverter, context, lowerToAlwaysFF, moduleLoweringStates);
patterns.add<FromImmutableLowering>(typeConverter, context,
moduleLoweringStates);
patterns.add<ClockCastLowering<seq::FromClockOp>>(typeConverter, context);
patterns.add<ClockCastLowering<seq::ToClockOp>>(typeConverter, context);
patterns.add<ClockGateLowering>(typeConverter, context);
Expand Down
38 changes: 25 additions & 13 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ Value ClockLowering::materializeValue(Value value) {
return {};
if (auto mapped = materializedValues.lookupOrNull(value))
return mapped;
if (auto fromImmutable = value.getDefiningOp<seq::FromImmutableOp>())
// Immutable value is pre-materialized so directly lookup the input.
return materializedValues.lookup(fromImmutable.getInput());

if (!shouldMaterialize(value))
return value;

Expand Down Expand Up @@ -427,19 +431,27 @@ LogicalResult ModuleLowering::lowerPrimaryOutputs() {
}

LogicalResult ModuleLowering::lowerInitials() {
// Move all operations except for seq.yield to arc.initial op.
for (auto op : moduleOp.getOps<seq::InitialOp>()) {
auto terminator = cast<seq::YieldOp>(op.getBodyBlock()->getTerminator());
getInitial().builder.getBlock()->getOperations().splice(
getInitial().builder.getBlock()->begin(),
op.getBodyBlock()->getOperations());

// Map seq.initial results to operands of the seq.yield op.
for (auto [result, operand] :
llvm::zip(op.getResults(), terminator.getOperands()))
getInitial().materializedValues.map(result, operand);
terminator.erase();
}
// Merge all seq.initial ops into a single seq.initial op.
auto result = circt::seq::mergeInitialOps(moduleOp.getBodyBlock());
if (failed(result))
return moduleOp.emitError() << "initial ops cannot be topologically sorted";

auto initialOp = *result;
if (!initialOp) // There is no seq.initial op.
return success();

// Move the operations of the merged initial op into the builder's block.
auto terminator =
cast<seq::YieldOp>(initialOp.getBodyBlock()->getTerminator());
getInitial().builder.getBlock()->getOperations().splice(
getInitial().builder.getBlock()->begin(),
initialOp.getBodyBlock()->getOperations());

// Map seq.initial results to their corresponding operands.
for (auto [result, operand] :
llvm::zip(initialOp.getResults(), terminator.getOperands()))
getInitial().materializedValues.map(result, operand);
terminator.erase();

return success();
}
Expand Down
Loading

0 comments on commit 79c5799

Please sign in to comment.