diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 2b724286ab..e12b47be10 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -348,7 +348,7 @@ cc_library( ":optimization_pass_registry", ":pass_base", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", diff --git a/xls/passes/bit_slice_simplification_pass.cc b/xls/passes/bit_slice_simplification_pass.cc index 4f7249556a..ef79622ee9 100644 --- a/xls/passes/bit_slice_simplification_pass.cc +++ b/xls/passes/bit_slice_simplification_pass.cc @@ -24,6 +24,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -623,11 +624,6 @@ absl::StatusOr SimplifyLiteralDynamicBitSlice( absl::StatusOr SimplifyScaledDynamicBitSlice(DynamicBitSlice* bit_slice) { int64_t bit_count = bit_slice->to_slice()->BitCountOrDie(); int64_t width = bit_slice->width(); - // TODO(epastor): Remove this restriction by padding the last array element - // with zeros as needed. - if (bit_count % width != 0) { - return false; - } Node* start = bit_slice->start(); XLS_ASSIGN_OR_RETURN(std::optional index, @@ -637,14 +633,27 @@ absl::StatusOr SimplifyScaledDynamicBitSlice(DynamicBitSlice* bit_slice) { } std::vector array_elements; - array_elements.reserve(bit_count / width); + array_elements.reserve((bit_count / width) + + static_cast(bit_count % width != 0)); for (int64_t element_start = 0; element_start < bit_count; element_start += width) { - XLS_ASSIGN_OR_RETURN(Node * array_element, - bit_slice->function_base()->MakeNode( - bit_slice->loc(), bit_slice->to_slice(), - /*start=*/element_start, - /*width=*/width)); + Node* array_element; + if (element_start + width <= bit_count) { + XLS_ASSIGN_OR_RETURN(array_element, + bit_slice->function_base()->MakeNode( + bit_slice->loc(), bit_slice->to_slice(), + /*start=*/element_start, + /*width=*/width)); + } else { + XLS_ASSIGN_OR_RETURN(Node * slice, + bit_slice->function_base()->MakeNode( + bit_slice->loc(), bit_slice->to_slice(), + /*start=*/element_start, + /*width=*/bit_count - element_start)); + XLS_ASSIGN_OR_RETURN(array_element, + bit_slice->function_base()->MakeNode( + SourceInfo(), slice, width, Op::kZeroExt)); + } array_elements.push_back(array_element); } @@ -697,11 +706,6 @@ absl::StatusOr SimplifyDynamicBitSlice(DynamicBitSlice* bit_slice) { absl::StatusOr SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) { int64_t bit_count = update->to_update()->BitCountOrDie(); int64_t width = update->update_value()->BitCountOrDie(); - // TODO(epastor): Remove this restriction by padding the last array element - // with zeros as needed, then removing them from the result. - if (bit_count % width != 0) { - return false; - } Node* start = update->start(); XLS_ASSIGN_OR_RETURN(std::optional index, @@ -711,14 +715,27 @@ absl::StatusOr SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) { } std::vector array_elements; - array_elements.reserve(bit_count / width); + array_elements.reserve((bit_count / width) + + static_cast(bit_count % width != 0)); for (int64_t element_start = 0; element_start < bit_count; element_start += width) { - XLS_ASSIGN_OR_RETURN(Node * array_element, - update->function_base()->MakeNode( - update->loc(), update->to_update(), - /*start=*/element_start, - /*width=*/width)); + Node* array_element; + if (element_start + width <= bit_count) { + XLS_ASSIGN_OR_RETURN(array_element, + update->function_base()->MakeNode( + update->loc(), update->to_update(), + /*start=*/element_start, + /*width=*/width)); + } else { + XLS_ASSIGN_OR_RETURN(Node * slice, + update->function_base()->MakeNode( + update->loc(), update->to_update(), + /*start=*/element_start, + /*width=*/bit_count - element_start)); + XLS_ASSIGN_OR_RETURN(array_element, + update->function_base()->MakeNode( + SourceInfo(), slice, width, Op::kZeroExt)); + } array_elements.push_back(array_element); } XLS_ASSIGN_OR_RETURN(Node * array, update->function_base()->MakeNode( @@ -744,6 +761,15 @@ absl::StatusOr SimplifyScaledBitSliceUpdate(BitSliceUpdate* update) { array_update->function_base()->MakeNode( array_update->loc(), array_update, /*indices=*/std::vector({element_index}))); + if (bit_count - (i * width) < width) { + CHECK_EQ(i, array_elements.size() - 1); + + // Disregard any bits past the end of the original bit vector. + XLS_ASSIGN_OR_RETURN(updated_array_element, + array_update->function_base()->MakeNode( + SourceInfo(), updated_array_element, /*start=*/0, + /*width=*/bit_count - (i * width))); + } updated_array_elements.push_back(updated_array_element); } diff --git a/xls/passes/bit_slice_simplification_pass_test.cc b/xls/passes/bit_slice_simplification_pass_test.cc index 1fc4fa7cbc..1931b1a4be 100644 --- a/xls/passes/bit_slice_simplification_pass_test.cc +++ b/xls/passes/bit_slice_simplification_pass_test.cc @@ -692,9 +692,9 @@ TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateOutOfBounds) { TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithScaledIndex) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); - Type* u30 = p->GetBitsType(30); + Type* u23 = p->GetBitsType(23); Type* u2 = p->GetBitsType(2); - BValue x = fb.Param("x", u30); + BValue x = fb.Param("x", u23); BValue z = fb.Param("z", u2); BValue index = fb.UMul(fb.ZeroExtend(z, 5), fb.Literal(UBits(10, 5))); fb.DynamicBitSlice(x, index, /*width=*/10); @@ -705,11 +705,14 @@ TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithScaledIndex) { ASSERT_THAT( f->return_value(), - m::Select(m::ZeroExt(m::Param("z")), - {m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10), - m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10), - m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/10)}, - m::Literal(0))); + m::Select( + m::ZeroExt(m::Param("z")), + { + m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10), + m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10), + m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/3)), + }, + m::Literal(0))); } TEST_F(BitSliceSimplificationPassTest, @@ -731,14 +734,16 @@ TEST_F(BitSliceSimplificationPassTest, f->return_value(), m::Select(m::BitSlice(m::UMul(m::ZeroExt(m::Param("z")), m::Literal(4)), /*start=*/2, /*width=*/3), - {m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)})); + { + m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4), + })); } TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithShiftedIndex) { @@ -759,14 +764,16 @@ TEST_F(BitSliceSimplificationPassTest, DynamicBitSliceWithShiftedIndex) { f->return_value(), m::Select(m::BitSlice(m::Shll(m::ZeroExt(m::Param("z")), m::Literal(2)), /*start=*/2, /*width=*/3), - {m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)})); + { + m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4), + })); } TEST_F(BitSliceSimplificationPassTest, @@ -787,23 +794,25 @@ TEST_F(BitSliceSimplificationPassTest, ASSERT_THAT( f->return_value(), m::Select(m::Concat(m::Param("z")), - {m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), - m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4)})); + { + m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/4, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/8, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/12, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/16, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/24, /*width=*/4), + m::BitSlice(m::Param("x"), /*start=*/28, /*width=*/4), + })); } TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateWithScaledIndex) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); - Type* u30 = p->GetBitsType(30); + Type* u23 = p->GetBitsType(23); Type* u10 = p->GetBitsType(10); Type* u2 = p->GetBitsType(2); - BValue x = fb.Param("x", u30); + BValue x = fb.Param("x", u23); BValue y = fb.Param("y", u10); BValue z = fb.Param("z", u2); BValue index = fb.UMul(fb.ZeroExtend(z, 5), fb.Literal(UBits(10, 5))); @@ -813,15 +822,18 @@ TEST_F(BitSliceSimplificationPassTest, BitSliceUpdateWithScaledIndex) { solvers::z3::ScopedVerifyEquivalence sve(f); ASSERT_THAT(Run(f), IsOkAndHolds(true)); - auto array = m::Array(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10), - m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10), - m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/10)); + auto array = m::Array( + m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/10), + m::BitSlice(m::Param("x"), /*start=*/10, /*width=*/10), + m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/20, /*width=*/3))); auto array_update = m::ArrayUpdate(array, m::Param("y"), {m::ZeroExt(m::Param("z"))}); - ASSERT_THAT(f->return_value(), - m::Concat(m::ArrayIndex(array_update, {m::Literal(2)}), - m::ArrayIndex(array_update, {m::Literal(1)}), - m::ArrayIndex(array_update, {m::Literal(0)}))); + ASSERT_THAT( + f->return_value(), + m::Concat(m::BitSlice(m::ArrayIndex(array_update, {m::Literal(2)}), + /*start=*/0, /*width=*/3), + m::ArrayIndex(array_update, {m::Literal(1)}), + m::ArrayIndex(array_update, {m::Literal(0)}))); } TEST_F(BitSliceSimplificationPassTest,