Skip to content

Commit

Permalink
Merge pull request google#1818 from xlsynth:cdleary/2024-12-30-negati…
Browse files Browse the repository at this point in the history
…ve-size-slice

PiperOrigin-RevId: 711527414
  • Loading branch information
copybara-github committed Jan 2, 2025
2 parents 4364059 + 5f9d7a8 commit be6d8f4
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 120 deletions.
3 changes: 2 additions & 1 deletion xls/dslx/bytecode/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class Bytecode {
// otherwise it'll be logical.
kShr,
// Slices out a subset of the bits-typed value on TOS2,
// starting at index TOS1 and ending at index TOS0.
// starting at index TOS1 with bitwidth at TOS0.
// Note: the start index and the bitwidth should both be non-negative.
kSlice,
// Creates a new proc interpreter using the data in the optional data member
// (as a `SpawnData`).
Expand Down
122 changes: 53 additions & 69 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -979,88 +979,72 @@ absl::Status BytecodeEmitter::HandleFormatMacro(const FormatMacro* node) {
return absl::OkStatus();
}

static absl::StatusOr<int64_t> GetValueWidth(const TypeInfo* type_info,
Expr* expr) {
std::optional<Type*> maybe_type = type_info->GetItem(expr);
absl::Status BytecodeEmitter::HandleSlice(const Index* node, Slice* slice) {
std::optional<StartAndWidth> saw = type_info_->GetSliceStartAndWidth(
slice,
caller_bindings_.has_value() ? *caller_bindings_ : ParametricEnv());
if (!saw.has_value()) {
return absl::InternalError(absl::StrFormat(
"Expected start-and-width data for slice `%s` @ %s to be populated "
"from type checking.",
slice->ToString(), node->span().ToString(file_table())));
}

XLS_RET_CHECK_GE(saw->start, 0);
XLS_RET_CHECK_GE(saw->width, 0);

// Helper for either getting the span of the given slice index or, if that
// slice index is nullptr, getting the span from the index operation as a
// fallback.
auto span_or_default = [&](Expr* slice_index) -> Span {
if (slice_index != nullptr) {
return slice_index->span();
}
return node->span();
};

bytecode_.push_back(Bytecode(span_or_default(slice->start()),
Bytecode::Op::kLiteral,
InterpValue::MakeU32(saw->start)));
bytecode_.push_back(Bytecode(span_or_default(slice->limit()),
Bytecode::Op::kLiteral,
InterpValue::MakeU32(saw->width)));

bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleWidthSlice(const Index* node,
WidthSlice* width_slice) {
XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this));

std::optional<Type*> maybe_type = type_info_->GetItem(width_slice->width());
if (!maybe_type.has_value()) {
return absl::InternalError(
"Could not find concrete type for slice component.");
return absl::InternalError(absl::StrCat(
"Could not find concrete type for slice width parameter \"",
width_slice->width()->ToString(), "\"."));
}
return maybe_type.value()->GetTotalBitCount()->GetAsInt64();

MetaType* type = dynamic_cast<MetaType*>(maybe_type.value());
XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString();
XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString();

bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kWidthSlice, type->CloneToUnique()));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleIndex(const Index* node) {
XLS_RETURN_IF_ERROR(node->lhs()->AcceptExpr(this));

if (std::holds_alternative<Slice*>(node->rhs())) {
Slice* slice = std::get<Slice*>(node->rhs());
if (slice->start() == nullptr) {
int64_t start_width;
if (slice->limit() == nullptr) {
// TODO(rspringer): Define a uniform `usize` to avoid specifying magic
// numbers here. This is the default size used for untyped numbers in
// the typechecker.
start_width = 32;
} else {
XLS_ASSIGN_OR_RETURN(start_width,
GetValueWidth(type_info_, slice->limit()));
}
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeSBits(start_width, 0)));
} else {
XLS_RETURN_IF_ERROR(slice->start()->AcceptExpr(this));
}

if (slice->limit() == nullptr) {
std::optional<Type*> maybe_type = type_info_->GetItem(node->lhs());
if (!maybe_type.has_value()) {
return absl::InternalError("Could not find concrete type for slice.");
}
Type* type = maybe_type.value();
// These will never fail.
absl::StatusOr<TypeDim> dim = type->GetTotalBitCount();
CHECK_OK(dim);
absl::StatusOr<int64_t> width = dim->GetAsInt64();
CHECK_OK(width);

int64_t limit_width;
if (slice->start() == nullptr) {
// TODO(rspringer): Define a uniform `usize` to avoid specifying magic
// numbers here. This is the default size used for untyped numbers in
// the typechecker.
limit_width = 32;
} else {
XLS_ASSIGN_OR_RETURN(limit_width,
GetValueWidth(type_info_, slice->start()));
}
bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeSBits(limit_width, *width)));
} else {
XLS_RETURN_IF_ERROR(slice->limit()->AcceptExpr(this));
}
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSlice));
return absl::OkStatus();
return HandleSlice(node, slice);
}

if (std::holds_alternative<WidthSlice*>(node->rhs())) {
WidthSlice* width_slice = std::get<WidthSlice*>(node->rhs());
XLS_RETURN_IF_ERROR(width_slice->start()->AcceptExpr(this));

std::optional<Type*> maybe_type = type_info_->GetItem(width_slice->width());
if (!maybe_type.has_value()) {
return absl::InternalError(absl::StrCat(
"Could not find concrete type for slice width parameter \"",
width_slice->width()->ToString(), "\"."));
}

MetaType* type = dynamic_cast<MetaType*>(maybe_type.value());
XLS_RET_CHECK(type != nullptr) << maybe_type.value()->ToString();
XLS_RET_CHECK(IsBitsLike(*type->wrapped())) << type->ToString();

bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kWidthSlice,
type->CloneToUnique()));
return absl::OkStatus();
return HandleWidthSlice(node, width_slice);
}

// Otherwise, it's a regular [array or tuple] index op.
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleFunctionRef(const FunctionRef* node) override;
absl::Status HandleZeroMacro(const ZeroMacro* node) override;
absl::Status HandleAllOnesMacro(const AllOnesMacro* node) override;

absl::Status HandleIndex(const Index* node) override;
absl::Status HandleSlice(const Index* node, Slice* slice);
absl::Status HandleWidthSlice(const Index* node, WidthSlice* width_slice);

absl::Status HandleInvocation(const Invocation* node) override;
absl::Status HandleLet(const Let* node) override;
absl::Status HandleMatch(const Match* node) override;
Expand Down
53 changes: 7 additions & 46 deletions xls/dslx/bytecode/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1253,54 +1253,15 @@ absl::Status BytecodeInterpreter::EvalShr(const Bytecode& bytecode) {
}

absl::Status BytecodeInterpreter::EvalSlice(const Bytecode& bytecode) {
XLS_ASSIGN_OR_RETURN(InterpValue limit, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue length, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue start, Pop());
XLS_ASSIGN_OR_RETURN(InterpValue basis, Pop());
XLS_ASSIGN_OR_RETURN(int64_t basis_bit_count, basis.GetBitCount());
XLS_ASSIGN_OR_RETURN(int64_t start_bit_count, start.GetBitCount());

InterpValue zero = InterpValue::MakeSBits(start_bit_count, 0);
InterpValue basis_length =
InterpValue::MakeSBits(start_bit_count, basis_bit_count);

XLS_ASSIGN_OR_RETURN(InterpValue start_lt_zero, start.Lt(zero));
if (start_lt_zero.IsTrue()) {
// Remember, start is negative if we're here.
XLS_ASSIGN_OR_RETURN(start, basis_length.Add(start));
// If start is _still_ less than zero, then we clamp to zero.
XLS_ASSIGN_OR_RETURN(start_lt_zero, start.Lt(zero));
if (start_lt_zero.IsTrue()) {
start = zero;
}
}

XLS_ASSIGN_OR_RETURN(InterpValue limit_lt_zero, limit.Lt(zero));
if (limit_lt_zero.IsTrue()) {
// Ditto.
XLS_ASSIGN_OR_RETURN(limit, basis_length.Add(limit));
XLS_ASSIGN_OR_RETURN(limit_lt_zero, limit.Lt(zero));
if (limit_lt_zero.IsTrue()) {
limit = zero;
}
}

// If limit extends past the basis, then we truncate limit.
XLS_ASSIGN_OR_RETURN(InterpValue limit_ge_basis_length,
limit.Ge(basis_length));
if (limit_ge_basis_length.IsTrue()) {
limit =
InterpValue::MakeSBits(start_bit_count, basis.GetBitCount().value());
}
XLS_ASSIGN_OR_RETURN(InterpValue length, limit.Sub(start));

// At this point, both start and length must be nonnegative, so we force them
// to UBits, since Slice expects that.
XLS_ASSIGN_OR_RETURN(int64_t start_value, start.GetBitValueViaSign());
XLS_ASSIGN_OR_RETURN(int64_t length_value, length.GetBitValueViaSign());
XLS_RET_CHECK_GE(start_value, 0);
XLS_RET_CHECK_GE(length_value, 0);
start = InterpValue::MakeBits(/*is_signed=*/false, start.GetBitsOrDie());
length = InterpValue::MakeBits(/*is_signed=*/false, length.GetBitsOrDie());
XLS_RET_CHECK(length.IsUBits())
<< "Slice length is not unsigned bits: " << length.ToString();
XLS_RET_CHECK(start.IsUBits())
<< "Slice start is not unsigned bits: " << start.ToString();
XLS_RET_CHECK(basis.IsUBits())
<< "Slice basis is not unsigned bits: " << basis.ToString();
XLS_ASSIGN_OR_RETURN(InterpValue result, basis.Slice(start, length));
stack_.Push(result);
return absl::OkStatus();
Expand Down
22 changes: 22 additions & 0 deletions xls/dslx/bytecode/bytecode_interpreter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ TEST_F(BytecodeInterpreterTest, DupLiteral) {
EXPECT_EQ(result.ToString(), "u32:42");
}

// Note: this test case spews a stack trace as the `!stack_.empty()` error comes
// from a `XLS_RET_CHECK`, this is really an internals-flags-an-error test not a
// behavioral test.
TEST_F(BytecodeInterpreterTest, DupEmptyStack) {
std::vector<Bytecode> bytecodes;
bytecodes.emplace_back(kFakeSpan, Bytecode::Op::kDup);
Expand Down Expand Up @@ -1318,6 +1321,25 @@ fn negative_end_slice() -> u16 {
EXPECT_EQ(int_value, 0xbeef);
}

// https://github.com/google/xls/issues/1784 -- note that the size of a slice
// can never be negative, but the bytecode interpreter can have inflated
// expectations that don't line up with what the type checker will accept. This
// test is to ensure that we don't crash when we encounter this case.
TEST_F(BytecodeInterpreterTest, NegativeSizeSlice) {
constexpr std::string_view kProgram = R"(
fn negative_size_slice() -> bits[0] {
(u32:0x42)[5:3]
}
)";

XLS_ASSERT_OK_AND_ASSIGN(InterpValue value,
Interpret(kProgram, "negative_size_slice"));
ASSERT_TRUE(value.IsUBits());
const Bits& bits = value.GetBitsOrDie();
EXPECT_EQ(bits.bit_count(), 0);
EXPECT_EQ(bits.ToUint64().value(), 0);
}

TEST_F(BytecodeInterpreterTest, WidthSlice) {
constexpr std::string_view kProgram = R"(
fn width_slice() -> s16 {
Expand Down
11 changes: 7 additions & 4 deletions xls/dslx/interp_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,11 @@ bool InterpValue::operator==(const InterpValue& rhs) const { return Eq(rhs); }
const InterpValue& lhs, const InterpValue& rhs, CompareF ucmp,
CompareF scmp) {
if (lhs.tag_ != rhs.tag_) {
return absl::InvalidArgumentError(absl::StrFormat(
"Same tag is required for a comparison operation: lhs %s rhs %s",
TagToString(lhs.tag_), TagToString(rhs.tag_)));
return absl::InvalidArgumentError(
absl::StrFormat("Same tag is required for a comparison operation: lhs "
"tag: %s, rhs tag: %s, lhs value: %s, rhs value: %s",
TagToString(lhs.tag_), TagToString(rhs.tag_),
lhs.ToString(), rhs.ToString()));
}
switch (lhs.tag_) {
case InterpValueTag::kUBits:
Expand Down Expand Up @@ -692,7 +694,8 @@ absl::StatusOr<Bits> InterpValue::GetBits() const {
return std::get<EnumData>(payload_).value;
}

return absl::InvalidArgumentError("Value does not contain bits.");
return absl::InvalidArgumentError(
absl::StrFormat("Value %s does not contain bits.", ToString()));
}

const Bits& InterpValue::GetBitsOrDie() const {
Expand Down

0 comments on commit be6d8f4

Please sign in to comment.