Skip to content

Commit

Permalink
Add a depth limit to next-value splitting
Browse files Browse the repository at this point in the history
Without this, certain inputs (especially those produced by XLS[cc]) can end up spending absurd amounts of time subdividing next-value nodes.

PiperOrigin-RevId: 611463956
  • Loading branch information
ericastor authored and copybara-github committed Feb 29, 2024
1 parent f0e5220 commit 42b0727
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 48 deletions.
1 change: 1 addition & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2654,6 +2654,7 @@ cc_library(
":optimization_pass",
":optimization_pass_registry",
":pass_base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
111 changes: 73 additions & 38 deletions xls/passes/next_value_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -67,14 +68,18 @@ absl::StatusOr<bool> ModernizeNextValues(Proc* proc) {
return proc->GetStateElementCount() > 0;
}

absl::Status RemoveNextValue(Proc* proc, Next* next) {
absl::Status RemoveNextValue(Proc* proc, Next* next,
absl::flat_hash_map<Next*, int64_t>& split_depth) {
XLS_RETURN_IF_ERROR(
next->ReplaceUsesWithNew<Literal>(Value::Tuple({})).status());
if (auto it = split_depth.find(next); it != split_depth.end()) {
split_depth.erase(it);
}
return proc->RemoveNode(next);
}

absl::StatusOr<std::optional<std::vector<Next*>>> RemoveLiteralPredicate(
Proc* proc, Next* next) {
Proc* proc, Next* next, absl::flat_hash_map<Next*, int64_t>& split_depth) {
if (!next->predicate().has_value()) {
return std::nullopt;
}
Expand All @@ -87,7 +92,7 @@ absl::StatusOr<std::optional<std::vector<Next*>>> RemoveLiteralPredicate(
if (literal_predicate->value().IsAllZeros()) {
XLS_VLOG(2) << "Identified node as dead due to zero predicate; removing: "
<< *next;
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next));
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next, split_depth));
return std::vector<Next*>();
}
XLS_VLOG(2) << "Identified node as always live; removing predicate: "
Expand All @@ -100,12 +105,16 @@ absl::StatusOr<std::optional<std::vector<Next*>>> RemoveLiteralPredicate(
if (next->HasAssignedName()) {
new_next->SetName(next->GetName());
}
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next));
if (split_depth.contains(next)) {
split_depth[new_next] = split_depth[next];
}
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next, split_depth));
return std::vector<Next*>({new_next});
}

absl::StatusOr<std::optional<std::vector<Next*>>> SplitSmallSelect(
Proc* proc, Next* next, const OptimizationPassOptions& options) {
Proc* proc, Next* next, const OptimizationPassOptions& options,
absl::flat_hash_map<Next*, int64_t>& split_depth) {
if (!options.split_next_value_selects.has_value()) {
return std::nullopt;
}
Expand All @@ -119,6 +128,11 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitSmallSelect(
return std::nullopt;
}

int64_t depth = 1;
if (auto it = split_depth.find(next); it != split_depth.end()) {
depth = it->second + 1;
}

std::vector<Next*> new_next_values;
for (int64_t i = 0; i < selected_value->cases().size(); ++i) {
XLS_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -149,6 +163,7 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitSmallSelect(
/*value=*/selected_value->cases()[i],
predicate, name));
new_next_values.push_back(new_next);
split_depth[new_next] = depth;
}

if (selected_value->default_value().has_value()) {
Expand Down Expand Up @@ -181,19 +196,25 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitSmallSelect(
/*value=*/*selected_value->default_value(),
predicate, name));
new_next_values.push_back(new_next);
split_depth[new_next] = depth;
}

XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next));
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next, split_depth));
return new_next_values;
}

absl::StatusOr<std::optional<std::vector<Next*>>> SplitPrioritySelect(
Proc* proc, Next* next) {
Proc* proc, Next* next, absl::flat_hash_map<Next*, int64_t>& split_depth) {
if (!next->value()->Is<PrioritySelect>()) {
return std::nullopt;
}
PrioritySelect* selected_value = next->value()->As<PrioritySelect>();

int64_t depth = 1;
if (auto it = split_depth.find(next); it != split_depth.end()) {
depth = it->second + 1;
}

std::vector<Next*> new_next_values;
for (int64_t i = 0; i < selected_value->cases().size(); ++i) {
absl::InlinedVector<Node*, 3> all_clauses;
Expand Down Expand Up @@ -224,6 +245,7 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitPrioritySelect(
/*value=*/selected_value->get_case(i),
/*predicate=*/case_predicate, name));
new_next_values.push_back(new_next);
split_depth[new_next] = depth;
}

// Default case; if all bits of the input are zero, `priority_sel` returns
Expand Down Expand Up @@ -259,13 +281,14 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitPrioritySelect(
/*value=*/default_value,
/*predicate=*/default_predicate, name));
new_next_values.push_back(new_next);
split_depth[new_next] = depth;

XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next));
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next, split_depth));
return new_next_values;
}

absl::StatusOr<std::optional<std::vector<Next*>>> SplitSafeOneHotSelect(
Proc* proc, Next* next) {
Proc* proc, Next* next, absl::flat_hash_map<Next*, int64_t>& split_depth) {
if (!next->value()->Is<OneHotSelect>()) {
return std::nullopt;
}
Expand All @@ -276,6 +299,11 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitSafeOneHotSelect(
return std::nullopt;
}

int64_t depth = 1;
if (auto it = split_depth.find(next); it != split_depth.end()) {
depth = it->second + 1;
}

std::vector<Next*> new_next_values;
for (int64_t i = 0; i < selected_value->cases().size(); ++i) {
XLS_ASSIGN_OR_RETURN(
Expand All @@ -302,8 +330,9 @@ absl::StatusOr<std::optional<std::vector<Next*>>> SplitSafeOneHotSelect(
/*value=*/selected_value->get_case(i),
/*predicate=*/case_predicate, name));
new_next_values.push_back(new_next);
split_depth[new_next] = depth;
}
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next));
XLS_RETURN_IF_ERROR(RemoveNextValue(proc, next, split_depth));
return new_next_values;
}

Expand All @@ -321,49 +350,55 @@ absl::StatusOr<bool> NextValueOptimizationPass::RunOnProcInternal(

std::deque<Next*> worklist(proc->next_values().begin(),
proc->next_values().end());
absl::flat_hash_map<Next*, int64_t> split_depth;
while (!worklist.empty()) {
Next* next = worklist.front();
worklist.pop_front();

XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> literal_predicate_next_values,
RemoveLiteralPredicate(proc, next));
RemoveLiteralPredicate(proc, next, split_depth));
if (literal_predicate_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(), literal_predicate_next_values->begin(),
literal_predicate_next_values->end());
continue;
}

XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_select_next_values,
SplitSmallSelect(proc, next, options));
if (split_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(), split_select_next_values->begin(),
split_select_next_values->end());
continue;
}
if (auto it = split_depth.find(next);
SplitsEnabled(opt_level_) && max_split_depth_ > 0 &&
(it == split_depth.end() || it->second < max_split_depth_)) {
XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_select_next_values,
SplitSmallSelect(proc, next, options, split_depth));
if (split_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(), split_select_next_values->begin(),
split_select_next_values->end());
continue;
}

XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_priority_select_next_values,
SplitPrioritySelect(proc, next));
if (split_priority_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(),
split_priority_select_next_values->begin(),
split_priority_select_next_values->end());
continue;
}
XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_priority_select_next_values,
SplitPrioritySelect(proc, next, split_depth));
if (split_priority_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(),
split_priority_select_next_values->begin(),
split_priority_select_next_values->end());
continue;
}

XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_one_hot_select_next_values,
SplitSafeOneHotSelect(proc, next));
if (split_one_hot_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(), split_one_hot_select_next_values->begin(),
split_one_hot_select_next_values->end());
continue;
XLS_ASSIGN_OR_RETURN(
std::optional<std::vector<Next*>> split_one_hot_select_next_values,
SplitSafeOneHotSelect(proc, next, split_depth));
if (split_one_hot_select_next_values.has_value()) {
changed = true;
worklist.insert(worklist.end(),
split_one_hot_select_next_values->begin(),
split_one_hot_select_next_values->end());
continue;
}
}
}

Expand Down
14 changes: 12 additions & 2 deletions xls/passes/next_value_optimization_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#ifndef XLS_PASSES_NEXT_VALUE_OPTIMIZATION_PASS_H_
#define XLS_PASSES_NEXT_VALUE_OPTIMIZATION_PASS_H_

#include <cstdint>
#include <string_view>

#include "absl/status/statusor.h"
#include "xls/ir/proc.h"
#include "xls/passes/optimization_pass.h"
Expand All @@ -37,11 +39,19 @@ namespace xls {
class NextValueOptimizationPass : public OptimizationProcPass {
public:
static constexpr std::string_view kName = "next_value_opt";
NextValueOptimizationPass()
: OptimizationProcPass(kName, "Next Value Optimization") {}

static constexpr int64_t kDefaultMaxSplitDepth = 10;
explicit NextValueOptimizationPass(
int64_t opt_level = kMaxOptLevel,
int64_t max_split_depth = kDefaultMaxSplitDepth)
: OptimizationProcPass(kName, "Next Value Optimization"),
opt_level_(opt_level),
max_split_depth_(max_split_depth) {}
~NextValueOptimizationPass() override = default;

protected:
const int64_t opt_level_;
const int64_t max_split_depth_;
absl::StatusOr<bool> RunOnProcInternal(Proc* proc,
const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
46 changes: 40 additions & 6 deletions xls/passes/next_value_optimization_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,16 @@ class NextValueOptimizationPassTest : public IrTestBase {

absl::StatusOr<bool> Run(
Package* p,
std::optional<int64_t> split_next_value_selects = std::nullopt) {
std::optional<int64_t> split_next_value_selects = std::nullopt,
std::optional<int64_t> split_depth_limit = std::nullopt) {
PassResults results;
OptimizationPassOptions options;
options.split_next_value_selects = split_next_value_selects;
return NextValueOptimizationPass().Run(p, options, &results);
return NextValueOptimizationPass(
kMaxOptLevel,
split_depth_limit.value_or(
NextValueOptimizationPass::kDefaultMaxSplitDepth))
.Run(p, options, &results);
}
};

Expand Down Expand Up @@ -255,19 +260,48 @@ TEST_F(NextValueOptimizationPassTest, CascadingSmallSelectsNextValue) {
EXPECT_THAT(Run(p.get(), /*split_next_value_selects=*/2), IsOkAndHolds(true));
EXPECT_THAT(proc->next_values(),
UnorderedElementsAre(
m::Next(m::Param(), m::Literal(2),
m::Next(m::Param("x"), m::Literal(2),
m::And(m::Eq(m::Param("a"), m::Literal(0)),
m::Eq(m::Param("b"), m::Literal(0)))),
m::Next(m::Param(), m::Literal(1),
m::Next(m::Param("x"), m::Literal(1),
m::And(m::Eq(m::Param("a"), m::Literal(0)),
m::Eq(m::Param("b"), m::Literal(1)))),
m::Next(m::Param(), m::Literal(2),
m::Next(m::Param("x"), m::Literal(2),
m::And(m::Eq(m::Param("a"), m::Literal(1)),
m::Eq(m::Param("b"), m::Literal(0)))),
m::Next(m::Param(), m::Literal(3),
m::Next(m::Param("x"), m::Literal(3),
m::And(m::Eq(m::Param("a"), m::Literal(1)),
m::Eq(m::Param("b"), m::Literal(1))))));
}

TEST_F(NextValueOptimizationPassTest,
DepthLimitedCascadingSmallSelectsNextValue) {
auto p = CreatePackage();
ProcBuilder pb("p", "tkn", p.get());
BValue x = pb.StateElement("x", Value(UBits(0, 2)));
BValue a = pb.StateElement("a", Value(UBits(0, 1)));
BValue b = pb.StateElement("b", Value(UBits(0, 1)));
BValue select_b_1 = pb.Select(
b, std::vector{pb.Literal(UBits(2, 2)), pb.Literal(UBits(1, 2))});
BValue select_b_2 = pb.Select(
b, std::vector{pb.Literal(UBits(2, 2)), pb.Literal(UBits(3, 2))});
BValue select_a = pb.Select(a, std::vector{select_b_1, select_b_2});
pb.Next(/*param=*/x, /*value=*/select_a);
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build(pb.GetTokenParam()));

EXPECT_THAT(
Run(p.get(), /*split_next_value_selects=*/2, /*split_depth_limit=*/1),
IsOkAndHolds(true));
EXPECT_THAT(
proc->next_values(),
UnorderedElementsAre(
m::Next(m::Param("x"),
m::Select(m::Param("b"), {m::Literal(2), m::Literal(1)}),
m::Eq(m::Param("a"), m::Literal(0))),
m::Next(m::Param("x"),
m::Select(m::Param("b"), {m::Literal(2), m::Literal(3)}),
m::Eq(m::Param("a"), m::Literal(1)))));
}

} // namespace
} // namespace xls
4 changes: 2 additions & 2 deletions xls/passes/optimization_pass_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ std::unique_ptr<OptimizationCompoundPass> CreateOptimizationPassPipeline(
top->Add<ProcStateFlatteningPass>();
top->Add<IdentityRemovalPass>();
top->Add<DataflowSimplificationPass>();
top->Add<NextValueOptimizationPass>();
top->Add<NextValueOptimizationPass>(std::min(int64_t{3}, opt_level));
top->Add<ProcStateOptimizationPass>();
top->Add<DeadCodeEliminationPass>();

Expand All @@ -218,7 +218,7 @@ std::unique_ptr<OptimizationCompoundPass> CreateOptimizationPassPipeline(

top->Add<UselessAssertRemovalPass>();
top->Add<UselessIORemovalPass>();
top->Add<NextValueOptimizationPass>();
top->Add<NextValueOptimizationPass>(std::min(int64_t{3}, opt_level));
top->Add<ProcStateOptimizationPass>();
top->Add<DeadCodeEliminationPass>();

Expand Down

0 comments on commit 42b0727

Please sign in to comment.