Skip to content

Commit

Permalink
Extend dynamic_bit_slice & bit_slice_update simplification to mor…
Browse files Browse the repository at this point in the history
…e cases

Allow conversion of `dynamic_bit_slice` and `bit_slice_update` to "array" operations whenever the index is scaled by the width of the slice/update, even if the width doesn't divide evenly into the size of the bit vector.

PiperOrigin-RevId: 611555239
  • Loading branch information
ericastor authored and copybara-github committed Feb 29, 2024
1 parent 4190b35 commit 247f2c1
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 63 deletions.
2 changes: 1 addition & 1 deletion xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ cc_library(
":optimization_pass_registry",
":pass_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/numeric:bits",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
Expand Down
70 changes: 48 additions & 22 deletions xls/passes/bit_slice_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -623,11 +624,6 @@ absl::StatusOr<bool> SimplifyLiteralDynamicBitSlice(
absl::StatusOr<bool> SimplifyScaledDynamicBitSlice(DynamicBitSlice* bit_slice) {
int64_t bit_count = bit_slice->to_slice()->BitCountOrDie();
int64_t width = bit_slice->width();
// TODO(epastor): Remove this restriction by padding the last array element
// with zeros as needed.
if (bit_count % width != 0) {
return false;
}
Node* start = bit_slice->start();

XLS_ASSIGN_OR_RETURN(std::optional<Node*> index,
Expand All @@ -637,14 +633,27 @@ absl::StatusOr<bool> SimplifyScaledDynamicBitSlice(DynamicBitSlice* bit_slice) {
}

std::vector<Node*> array_elements;
array_elements.reserve(bit_count / width);
array_elements.reserve((bit_count / width) +
static_cast<int64_t>(bit_count % width != 0));
for (int64_t element_start = 0; element_start < bit_count;
element_start += width) {
XLS_ASSIGN_OR_RETURN(Node * array_element,
bit_slice->function_base()->MakeNode<BitSlice>(
bit_slice->loc(), bit_slice->to_slice(),
/*start=*/element_start,
/*width=*/width));
Node* array_element;
if (element_start + width <= bit_count) {
XLS_ASSIGN_OR_RETURN(array_element,
bit_slice->function_base()->MakeNode<BitSlice>(
bit_slice->loc(), bit_slice->to_slice(),
/*start=*/element_start,
/*width=*/width));
} else {
XLS_ASSIGN_OR_RETURN(Node * slice,
bit_slice->function_base()->MakeNode<BitSlice>(
bit_slice->loc(), bit_slice->to_slice(),
/*start=*/element_start,
/*width=*/bit_count - element_start));
XLS_ASSIGN_OR_RETURN(array_element,
bit_slice->function_base()->MakeNode<ExtendOp>(
SourceInfo(), slice, width, Op::kZeroExt));
}
array_elements.push_back(array_element);
}

Expand Down Expand Up @@ -697,11 +706,6 @@ absl::StatusOr<bool> SimplifyDynamicBitSlice(DynamicBitSlice* bit_slice) {
absl::StatusOr<bool> SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) {
int64_t bit_count = update->to_update()->BitCountOrDie();
int64_t width = update->update_value()->BitCountOrDie();
// TODO(epastor): Remove this restriction by padding the last array element
// with zeros as needed, then removing them from the result.
if (bit_count % width != 0) {
return false;
}
Node* start = update->start();

XLS_ASSIGN_OR_RETURN(std::optional<Node*> index,
Expand All @@ -711,14 +715,27 @@ absl::StatusOr<bool> SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) {
}

std::vector<Node*> array_elements;
array_elements.reserve(bit_count / width);
array_elements.reserve((bit_count / width) +
static_cast<int64_t>(bit_count % width != 0));
for (int64_t element_start = 0; element_start < bit_count;
element_start += width) {
XLS_ASSIGN_OR_RETURN(Node * array_element,
update->function_base()->MakeNode<BitSlice>(
update->loc(), update->to_update(),
/*start=*/element_start,
/*width=*/width));
Node* array_element;
if (element_start + width <= bit_count) {
XLS_ASSIGN_OR_RETURN(array_element,
update->function_base()->MakeNode<BitSlice>(
update->loc(), update->to_update(),
/*start=*/element_start,
/*width=*/width));
} else {
XLS_ASSIGN_OR_RETURN(Node * slice,
update->function_base()->MakeNode<BitSlice>(
update->loc(), update->to_update(),
/*start=*/element_start,
/*width=*/bit_count - element_start));
XLS_ASSIGN_OR_RETURN(array_element,
update->function_base()->MakeNode<ExtendOp>(
SourceInfo(), slice, width, Op::kZeroExt));
}
array_elements.push_back(array_element);
}
XLS_ASSIGN_OR_RETURN(Node * array, update->function_base()->MakeNode<Array>(
Expand All @@ -744,6 +761,15 @@ absl::StatusOr<bool> SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) {
array_update->function_base()->MakeNode<ArrayIndex>(
array_update->loc(), array_update,
/*indices=*/std::vector<Node*>({element_index})));
if (bit_count - (i * width) < width) {
CHECK_EQ(i, array_elements.size() - 1);

// Disregard any bits past the end of the original bit vector.
XLS_ASSIGN_OR_RETURN(updated_array_element,
array_update->function_base()->MakeNode<BitSlice>(
SourceInfo(), updated_array_element, /*start=*/0,
/*width=*/bit_count - (i * width)));
}
updated_array_elements.push_back(updated_array_element);
}

Expand Down
92 changes: 52 additions & 40 deletions xls/passes/bit_slice_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,9 @@ TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateOutOfBounds) {
TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithScaledIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u30 = p->GetBitsType(30);
Type* u23 = p->GetBitsType(23);
Type* u2 = p->GetBitsType(2);
BValue x = fb.Param("x", u30);
BValue x = fb.Param("x", u23);
BValue z = fb.Param("z", u2);
BValue index = fb.UMul(fb.ZeroExtend(z, 5), fb.Literal(UBits(10, 5)));
fb.DynamicBitSlice(x, index, /*width=*/10);
Expand All @@ -705,11 +705,14 @@ TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithScaledIndex) {

ASSERT_THAT(
f->return_value(),
m::Select(m::ZeroExt(m::Param("z")),
{m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/10)},
m::Literal(0)));
m::Select(
m::ZeroExt(m::Param("z")),
{
m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10),
m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/3)),
},
m::Literal(0)));
}

TEST_F(BitSliceSimplificationPassTest,
Expand All @@ -731,14 +734,16 @@ TEST_F(BitSliceSimplificationPassTest,
f->return_value(),
m::Select(m::BitSlice(m::UMul(m::ZeroExt(m::Param("z")), m::Literal(4)),
/*start=*/2, /*width=*/3),
{m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)}));
{
m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4),
}));
}

TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithShiftedIndex) {
Expand All @@ -759,14 +764,16 @@ TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithShiftedIndex) {
f->return_value(),
m::Select(m::BitSlice(m::Shll(m::ZeroExt(m::Param("z")), m::Literal(2)),
/*start=*/2, /*width=*/3),
{m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)}));
{
m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4),
}));
}

TEST_F(BitSliceSimplificationPassTest,
Expand All @@ -787,23 +794,25 @@ TEST_F(BitSliceSimplificationPassTest,
ASSERT_THAT(
f->return_value(),
m::Select(m::Concat(m::Param("z")),
{m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)}));
{
m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4),
m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4),
}));
}

TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateWithScaledIndex) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u30 = p->GetBitsType(30);
Type* u23 = p->GetBitsType(23);
Type* u10 = p->GetBitsType(10);
Type* u2 = p->GetBitsType(2);
BValue x = fb.Param("x", u30);
BValue x = fb.Param("x", u23);
BValue y = fb.Param("y", u10);
BValue z = fb.Param("z", u2);
BValue index = fb.UMul(fb.ZeroExtend(z, 5), fb.Literal(UBits(10, 5)));
Expand All @@ -813,15 +822,18 @@ TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateWithScaledIndex) {
solvers::z3::ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(f), IsOkAndHolds(true));

auto array = m::Array(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/10));
auto array = m::Array(
m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10),
m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10),
m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/3)));
auto array_update =
m::ArrayUpdate(array, m::Param("y"), {m::ZeroExt(m::Param("z"))});
ASSERT_THAT(f->return_value(),
m::Concat(m::ArrayIndex(array_update, {m::Literal(2)}),
m::ArrayIndex(array_update, {m::Literal(1)}),
m::ArrayIndex(array_update, {m::Literal(0)})));
ASSERT_THAT(
f->return_value(),
m::Concat(m::BitSlice(m::ArrayIndex(array_update, {m::Literal(2)}),
/*start=*/0, /*width=*/3),
m::ArrayIndex(array_update, {m::Literal(1)}),
m::ArrayIndex(array_update, {m::Literal(0)})));
}

TEST_F(BitSliceSimplificationPassTest,
Expand Down

0 comments on commit 247f2c1

Please sign in to comment.