Skip to content

Commit

Permalink
[XLS] Support StateRead predicates in state legalization
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709185318
  • Loading branch information
ericastor authored and copybara-github committed Dec 24, 2024
1 parent 4dc4230 commit b959a4e
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 4 deletions.
2 changes: 2 additions & 0 deletions xls/scheduling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,9 @@ cc_library(
"//xls/ir:op",
"//xls/ir:state_element",
"//xls/solvers:z3_ir_translator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
Expand Down
133 changes: 129 additions & 4 deletions xls/scheduling/proc_state_legalization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -63,6 +65,91 @@ absl::StatusOr<bool> ModernizeNextValues(Proc* proc) {
return proc->GetStateElementCount() > 0;
}

class StateReadPredicateRemover : public Proc::StateElementTransformer {
public:
~StateReadPredicateRemover() override = default;

absl::StatusOr<std::optional<Node*>> TransformReadPredicate(
Proc* proc, StateRead* old_state_read) override {
return std::nullopt;
}
};

// Ensure that `state_read` is either unpredicated or has a predicate that is
// true whenever any of its corresponding `next_value`s are active.
absl::StatusOr<bool> LegalizeStateReadPredicate(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
StateRead* state_read = proc->GetStateRead(state_element);
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
proc->next_values(state_read);
if (!state_read->predicate().has_value() || next_values.empty()) {
// No predicate; nothing to do.
return false;
}

if (absl::c_any_of(next_values, [](const Next* next) {
return !next->predicate().has_value();
})) {
StateReadPredicateRemover predicate_remover;
XLS_RETURN_IF_ERROR(proc->TransformStateElement(
state_read,
state_read->state_element()->initial_value(),
predicate_remover)
.status());
return true;
}

std::vector<Node*> predicates;
absl::flat_hash_set<Node*> predicates_set;
predicates.reserve(1 + next_values.size());
predicates_set.reserve(next_values.size());
for (Next* next : next_values) {
CHECK(next->predicate().has_value());
predicates.push_back(*next->predicate());
predicates_set.insert(*next->predicate());
}

Node* state_read_predicate = *state_read->predicate();
if (state_read_predicate->op() == Op::kOr &&
predicates_set ==
absl::flat_hash_set<Node*>(predicates.begin(), predicates.end())) {
// The predicate is already trivially correct; nothing to do.
return false;
}
if (predicates_set.size() == 1 &&
predicates.front() == state_read_predicate) {
// The predicate is already trivially correct; nothing to do.
return false;
}

predicates.insert(predicates.begin(), state_read_predicate);
XLS_ASSIGN_OR_RETURN(
Node * new_predicate,
NaryOrIfNeeded(proc, predicates, /*name=*/"", state_read->loc()));
XLS_RETURN_IF_ERROR(state_read->ReplaceOperandNumber(
*state_read->predicate_operand_number(), new_predicate));
return true;
}

absl::StatusOr<bool> LegalizeStateReadPredicates(
Proc* proc, const SchedulingPassOptions& options) {
bool changed = false;

for (StateElement* state_element : proc->StateElements()) {
XLS_ASSIGN_OR_RETURN(
bool state_read_changed,
LegalizeStateReadPredicate(proc, state_element, options));
if (state_read_changed) {
VLOG(4) << "Generalized read predicate for state element: "
<< state_element->name();
changed = true;
}
}

return changed;
}

absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
StateElement* state_element,
const SchedulingPassOptions& options) {
Expand All @@ -83,7 +170,7 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
state_read->loc(), /*state_read=*/state_read,
/*value=*/state_read,
/*predicate=*/std::nullopt,
/*predicate=*/state_read->predicate(),
absl::StrCat(state_element->name(), "_default"))
.status());
return true;
Expand All @@ -101,6 +188,22 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
continue;
}

if (state_read->predicate().has_value() && predicate->OpIn({Op::kAnd}) &&
predicate->operands().size() == 2) {
// Check to see if this is just an `and` with the state read predicate. If
// so, take the other operand & see if it's a not/nor of the other
// conditions.
if (predicate->operand(0) == *state_read->predicate()) {
predicate = predicate->operand(1);
} else if (predicate->operand(1) == *state_read->predicate()) {
predicate = predicate->operand(0);
} else {
// It's not, so we can't trivially recognize it as being of the right
// form.
continue;
}
}

if (!predicate->OpIn({Op::kNot, Op::kNor})) {
continue;
}
Expand Down Expand Up @@ -144,13 +247,21 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
// Explicitly mark the param as unchanged when no other `next_value` node is
// active.
XLS_ASSIGN_OR_RETURN(
Node * all_predicates_false,
Node * default_predicate,
NaryNorIfNeeded(proc, std::vector(predicates.begin(), predicates.end()),
/*name=*/"", state_read->loc()));
if (state_read->predicate().has_value()) {
XLS_ASSIGN_OR_RETURN(
default_predicate,
proc->MakeNode<NaryOp>(
state_read->loc(),
absl::MakeConstSpan({*state_read->predicate(), default_predicate}),
Op::kAnd));
}
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
state_read->loc(), /*state_read=*/state_read,
/*value=*/state_read,
/*predicate=*/all_predicates_false,
/*predicate=*/default_predicate,
absl::StrCat(state_element->name(), "_default"))
.status());
return true;
Expand Down Expand Up @@ -191,7 +302,21 @@ absl::StatusOr<bool> ProcStateLegalizationPass::RunOnFunctionBaseInternal(
return ModernizeNextValues(proc);
}

return AddDefaultNextValues(proc, options);
bool changed = false;

XLS_ASSIGN_OR_RETURN(bool read_predicates_changed,
LegalizeStateReadPredicates(proc, options));
if (read_predicates_changed) {
changed = true;
}

XLS_ASSIGN_OR_RETURN(bool default_nexts_added,
AddDefaultNextValues(proc, options));
if (default_nexts_added) {
changed = true;
}

return changed;
}

} // namespace xls
96 changes: 96 additions & 0 deletions xls/scheduling/proc_state_legalization_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "xls/scheduling/proc_state_legalization_pass.h"

#include <optional>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status_matchers.h"
Expand Down Expand Up @@ -269,5 +271,99 @@ TEST_F(ProcStateLegalizationPassTest,
m::Nor(positive_predicate.node(), negative_predicate.node()))));
}

TEST_F(ProcStateLegalizationPassTest, ProcWithPredicatedStateRead) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
BValue x_even =
pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32)));
BValue x_multiple_of_3 =
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
/*read_predicate=*/x_even);
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

ASSERT_THAT(Run(proc), IsOkAndHolds(true));

EXPECT_EQ(proc->GetStateRead(*proc->GetStateElement("x"))->predicate(),
std::nullopt);
EXPECT_THAT(
proc->GetStateRead(*proc->GetStateElement("y"))->predicate(),
Optional(m::Or(
m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)),
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)))));
EXPECT_THAT(
proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))),
UnorderedElementsAre(
m::Next(
m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)),
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))),
m::Next(m::StateRead("y"), m::StateRead("y"),
m::And(m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)),
m::Literal(0)),
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
m::Literal(0))),
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
m::Literal(0)))))));
}

TEST_F(ProcStateLegalizationPassTest,
ProcWithCorrectlyPredicatedStateReadAndNoDefaultNextNeeded) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
BValue x_multiple_of_3 =
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
BValue x_not_multiple_of_3 = pb.Not(x_multiple_of_3);
BValue disjunction = pb.Or(x_multiple_of_3, x_not_multiple_of_3);
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
/*read_predicate=*/disjunction);
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
pb.Next(y, y, x_not_multiple_of_3);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

ASSERT_THAT(Run(proc), IsOkAndHolds(false));
}

TEST_F(ProcStateLegalizationPassTest,
ProcWithPredicatedStateReadAndNoDefaultNextNeeded) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
BValue x_even =
pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32)));
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
/*read_predicate=*/x_even);
BValue x_multiple_of_3 =
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
BValue x_not_multiple_of_3 = pb.Not(x_multiple_of_3);
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
pb.Next(y, y, x_not_multiple_of_3);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

ASSERT_THAT(Run(proc), IsOkAndHolds(true));

EXPECT_THAT(
proc->GetStateRead(*proc->GetStateElement("y"))->predicate(),
Optional(
m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)),
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)),
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
m::Literal(0))))));
EXPECT_THAT(
proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))),
UnorderedElementsAre(
m::Next(
m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)),
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))),
m::Next(m::StateRead("y"), m::StateRead("y"),
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
m::Literal(0))))));
}

} // namespace
} // namespace xls

0 comments on commit b959a4e

Please sign in to comment.