Skip to content

Commit

Permalink
Support predicated-receives in proc unroll
Browse files Browse the repository at this point in the history
This is useful for proving equivalence of more complicated procs.

PiperOrigin-RevId: 702794647
  • Loading branch information
allight authored and copybara-github committed Dec 4, 2024
1 parent ed297cc commit 5863be3
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 7 deletions.
30 changes: 23 additions & 7 deletions xls/ir/proc_testutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
token_value_(std::move(token_value)) {}

absl::Status DefaultHandler(Node* n) override {
XLS_RETURN_IF_ERROR(fb_.GetError());
std::vector<Node*> new_ops;
for (Node* old_op : n->operands()) {
XLS_RET_CHECK(values_.contains({old_op, activation_}))
Expand All @@ -91,6 +92,7 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
}

absl::Status HandleStateRead(StateRead* state_read) override {
XLS_RETURN_IF_ERROR(fb_.GetError());
if (state_read->GetType()->IsToken()) {
values_[{state_read, activation_}] = fb_.Literal(token_value_);
return absl::OkStatus();
Expand All @@ -102,6 +104,7 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
}

absl::Status HandleSend(Send* s) override {
XLS_RETURN_IF_ERROR(fb_.GetError());
values_[{s, activation_}] = fb_.Literal(token_value_);
BValue predicate_value;
BValue data;
Expand All @@ -118,12 +121,22 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
return absl::OkStatus();
}

absl::Status HandleNext(Next* n) override { return absl::OkStatus(); }
absl::Status HandleNext(Next* n) override {
XLS_RETURN_IF_ERROR(fb_.GetError());
return absl::OkStatus();
}

absl::Status HandleReceive(Receive* r) override {
BValue real_data = fb_.Param(
absl::StrFormat("%s_act%d_read", r->channel_name(), activation_),
r->GetPayloadType());
XLS_RETURN_IF_ERROR(fb_.GetError());
BValue real_data;
if (recv_state_.contains({r->channel_name(), activation_})) {
real_data = recv_state_.at({r->channel_name(), activation_});
} else {
real_data = fb_.Param(
absl::StrFormat("%s_act%d_read", r->channel_name(), activation_),
r->GetPayloadType());
recv_state_[{r->channel_name(), activation_}] = real_data;
}
std::vector<BValue> result_values{fb_.Literal(token_value_)};
if (r->predicate()) {
result_values.push_back(fb_.Select(
Expand All @@ -139,9 +152,8 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
activation_),
fb_.package()->GetBitsType(1)));
}
values_[{r, activation_}] = fb_.Tuple(result_values);
BValue res = values_[{r, activation_}];
VLOG(2) << "got " << r << " -> " << res;
values_[{r, activation_}] = fb_.Tuple(std::move(result_values));
VLOG(2) << "got " << r << " -> " << values_[{r, activation_}];
return absl::OkStatus();
}

Expand All @@ -155,6 +167,7 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
}

absl::Status HandleAfterAll(AfterAll* aa) override {
XLS_RETURN_IF_ERROR(fb_.GetError());
// TODO: https://github.com/google/xls/issues/1375 - It would be nice to
// record this for real. The issue is that we'd need to figure out some way
// to flatten the tree in a consistent way.
Expand Down Expand Up @@ -190,6 +203,9 @@ class UnrollProcVisitor final : public DfsVisitorWithDefault {
absl::flat_hash_map<NodeActivation, BValue>& values_;
// A map of channel names to values sent on the most recent activation.
absl::flat_hash_map<std::string, BValue> send_state_;
// A map of channel names & activation to the values received on that
// activation.
absl::flat_hash_map<std::pair<std::string, int64_t>, BValue> recv_state_;
// Which activation are we inlining.
int64_t activation_;
// What value should we use for a token.
Expand Down
73 changes: 73 additions & 0 deletions xls/ir/proc_testutils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,5 +348,78 @@ TEST_F(UnrollProcTest, MultiProcsDifferentSizedState) {
EXPECT_THAT(TryProveEquivalence(f1, f2), IsOkAndHolds(IsProvenTrue()));
}

TEST_F(UnrollProcTest, PredicatedReceives) {
auto p = CreatePackage();
FunctionBuilder fb(absl::StrCat(TestName(), "_func"), p.get());
ProcBuilder pb(absl::StrCat(TestName(), "_proc"), p.get());
XLS_ASSERT_OK_AND_ASSIGN(
auto read_ch,
p->CreateStreamingChannel("do_read", ChannelOps::kReceiveOnly,
p->GetBitsType(1)));
XLS_ASSERT_OK_AND_ASSIGN(
auto write_ch,
p->CreateStreamingChannel("do_write", ChannelOps::kReceiveOnly,
p->GetBitsType(1)));
XLS_ASSERT_OK_AND_ASSIGN(
auto bar_ch, p->CreateStreamingChannel("bar_ch", ChannelOps::kReceiveOnly,
p->GetBitsType(4)));
XLS_ASSERT_OK_AND_ASSIGN(
auto ret_ch, p->CreateStreamingChannel("ret_ch", ChannelOps::kSendOnly,
p->GetBitsType(4)));
BValue tok = pb.StateElement("tok", Value::Token());
BValue state = pb.StateElement("cnt", UBits(1, 4));
BValue cont = pb.Receive(read_ch, tok);
BValue recv =
pb.ReceiveIf(bar_ch, pb.TupleIndex(cont, 0), pb.TupleIndex(cont, 1));
BValue nxt_val = pb.Add(state, pb.TupleIndex(recv, 1));
BValue do_write = pb.Receive(write_ch, pb.TupleIndex(recv, 0));
BValue final_tok = pb.SendIf(ret_ch, pb.TupleIndex(do_write, 0),
pb.TupleIndex(do_write, 1), nxt_val);
pb.Next(state, nxt_val);
pb.Next(tok, final_tok);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

BValue act_read_1 = fb.Param("do_read_ch_act0_read", p->GetBitsType(1));
BValue read_1 = fb.Param("bar_ch_act0_read", p->GetBitsType(4));
BValue act_write_1 = fb.Param("do_write_ch_act0_read", p->GetBitsType(1));
BValue act_read_2 = fb.Param("do_read_ch_act1_read", p->GetBitsType(1));
BValue read_2 = fb.Param("bar_ch_act1_read", p->GetBitsType(4));
BValue act_write_2 = fb.Param("do_write_ch_act1_read", p->GetBitsType(1));
BValue act_read_3 = fb.Param("do_read_ch_act2_read", p->GetBitsType(1));
BValue read_3 = fb.Param("bar_ch_act2_read", p->GetBitsType(4));
BValue act_write_3 = fb.Param("do_write_ch_act2_read", p->GetBitsType(1));
BValue act_read_4 = fb.Param("do_read_ch_act3_read", p->GetBitsType(1));
BValue read_4 = fb.Param("bar_ch_act3_read", p->GetBitsType(4));
BValue act_write_4 = fb.Param("do_write_ch_act3_read", p->GetBitsType(1));
BValue lit_zero = fb.Literal(UBits(0, 4));
BValue st_1 = fb.Literal(UBits(1, 4));
BValue st_2 = fb.Add(st_1, fb.Select(act_read_1, read_1, lit_zero));
BValue st_3 = fb.Add(st_2, fb.Select(act_read_2, read_2, lit_zero));
BValue st_4 = fb.Add(st_3, fb.Select(act_read_3, read_3, lit_zero));
BValue st_5 = fb.Add(st_4, fb.Select(act_read_4, read_4, lit_zero));

auto single_activation = [&](auto act_write, auto next_st) {
return fb.Tuple(
{fb.Tuple({act_write, fb.Select(act_write, next_st, lit_zero)})});
};

fb.Tuple({
single_activation(act_write_1, st_2),
single_activation(act_write_2, st_3),
single_activation(act_write_3, st_4),
single_activation(act_write_4, st_5),
});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

XLS_ASSERT_OK_AND_ASSIGN(
Function * converted,
UnrollProcToFunction(proc, 4, /*include_state=*/false));

RecordProperty("func", f->DumpIr());
RecordProperty("proc", proc->DumpIr());
RecordProperty("converted", converted->DumpIr());
EXPECT_THAT(TryProveEquivalence(f, converted), IsOkAndHolds(IsProvenTrue()));
}

} // namespace
} // namespace xls

0 comments on commit 5863be3

Please sign in to comment.