Skip to content

Commit

Permalink
Introduces encode and decode as DSLX built-ins
Browse files Browse the repository at this point in the history
We also start fuzzing them, which has already uncovered a number of minor bugs that we've fixed.

PiperOrigin-RevId: 583192725
  • Loading branch information
ericastor authored and copybara-github committed Nov 17, 2023
1 parent e8ed8ef commit d324bfe
Show file tree
Hide file tree
Showing 30 changed files with 451 additions and 20 deletions.
41 changes: 41 additions & 0 deletions docs_src/dslx_std.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,47 @@ functions:
assert_eq(u32:3, x2)
```

### `decode`

Converts a binary-encoded value into a one-hot value. For an operand value of
`n``interpreted as an unsigned number, the`n`-th result bit and only the`n`-th
result bit is set. Has the following signature:

```
fn decode<uN[W]>(x: uN[N]) -> uN[W]
```

The width of the decode operation may be less than the maximum value expressible
by the input (`2**N - 1`). If the encoded operand value is larger than the
number of bits of the result, the result is zero.

Example usage:
[`dslx/tests/decode.x`](https://github.com/google/xls/tree/main/xls/dslx/tests/decode.x).

See also the
[IR semantics for the `decode` op](./ir_semantics.md#decode).

### `encode`

Converts a one-hot value to a binary-encoded value of the "hot" bit of the
input. If the `n`-th bit and only the `n`-th bit of the operand is set, the
result is equal to the value `n` as an unsigned number. Has the following
signature:

```
fn encode(x: uN[N]) -> uN[ceil(log2(N))]
```

If multiple bits of the input are set, the result is equal to the logical or of
the results produced by the input bits individually. For example, if bit 3 and
bit 5 of an encode input are set the result is equal to `3 | 5 = 7`.

Example usage:
[`dslx/tests/encode.x`](https://github.com/google/xls/tree/main/xls/dslx/tests/encode.x).

See also the
[IR semantics for the `encode` op](./ir_semantics.md#encode).

### `one_hot`

Converts a value to one-hot form. Has the following signature:
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ cc_library(
hdrs = ["interp_value.h"],
deps = [
":value_format_descriptor",
"//xls/common:math_util",
"//xls/common/logging",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/data_structures:inline_bitmap",
"//xls/dslx/frontend:ast",
"//xls/ir:bits",
"//xls/ir:bits_ops",
Expand Down Expand Up @@ -345,7 +347,9 @@ cc_library(
":interp_value",
"//xls/common/logging",
"//xls/common/status:status_macros",
"//xls/dslx/frontend:ast_node",
"//xls/dslx/frontend:builtins_metadata",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:concrete_type",
"//xls/dslx/type_system:deduce",
"//xls/dslx/type_system:deduce_ctx",
Expand Down
12 changes: 12 additions & 0 deletions xls/dslx/bytecode/builtins.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,18 @@ absl::Status RunBuiltinGate(const Bytecode& bytecode, InterpreterStack& stack) {
stack);
}

absl::Status RunBuiltinEncode(const Bytecode& bytecode,
InterpreterStack& stack) {
XLS_VLOG(3) << "Executing builtin encode.";
XLS_RET_CHECK(!stack.empty());

XLS_ASSIGN_OR_RETURN(InterpValue input, stack.Pop());
XLS_ASSIGN_OR_RETURN(InterpValue output, input.Encode());
stack.Push(std::move(output));

return absl::OkStatus();
}

absl::Status RunBuiltinOneHot(const Bytecode& bytecode,
InterpreterStack& stack) {
return RunBinaryBuiltin(
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/bytecode/builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ absl::Status RunBuiltinBitSliceUpdate(const Bytecode& bytecode,

absl::Status RunBuiltinGate(const Bytecode& bytecode, InterpreterStack& stack);

absl::Status RunBuiltinEncode(const Bytecode& bytecode,
InterpreterStack& stack);

absl::Status RunBuiltinOneHot(const Bytecode& bytecode,
InterpreterStack& stack);

Expand Down
5 changes: 5 additions & 0 deletions xls/dslx/bytecode/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ absl::StatusOr<Bytecode::Op> OpFromString(std::string_view s) {
if (s == "create_tuple") {
return Bytecode::Op::kCreateTuple;
}
if (s == "decode") {
return Bytecode::Op::kDecode;
}
if (s == "div") {
return Bytecode::Op::kDiv;
}
Expand Down Expand Up @@ -222,6 +225,8 @@ std::string OpToString(Bytecode::Op op) {
return "create_array";
case Bytecode::Op::kCreateTuple:
return "create_tuple";
case Bytecode::Op::kDecode:
return "decode";
case Bytecode::Op::kDiv:
return "div";
case Bytecode::Op::kDup:
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/bytecode/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class Bytecode {
// Creates an N-tuple (N given in the data argument) from the values on the
// stack.
kCreateTuple,
// Decodes the element on top of the stack to a one-hot of the type given as
// the parametric arg.
kDecode,
// Divides the N-1th value on the stack by the Nth value.
kDiv,
// Determines remainder of division of the N-1th value by the Nth value.
Expand Down
19 changes: 19 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,21 @@ absl::Status BytecodeEmitter::HandleBuiltinSendIf(const Invocation* node) {
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleBuiltinDecode(const Invocation* node) {
XLS_VLOG(5) << "BytecodeEmitter::HandleInvocation - Decode @ "
<< node->span();

const Expr* from_expr = node->args().at(0);
XLS_RETURN_IF_ERROR(from_expr->AcceptExpr(this));

XLS_ASSIGN_OR_RETURN(BitsType * to, GetTypeOfNodeAsBits(node, type_info_));

bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kDecode, to->CloneToUnique()));

return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleBuiltinCheckedCast(const Invocation* node) {
XLS_VLOG(5) << "BytecodeEmitter::HandleInvocation - CheckedCast @ "
<< node->span();
Expand Down Expand Up @@ -1171,6 +1186,10 @@ absl::Status BytecodeEmitter::HandleInvocation(const Invocation* node) {
return absl::OkStatus();
}

if (name_ref->identifier() == "decode") {
return HandleBuiltinDecode(node);
}

if (name_ref->identifier() == "widening_cast") {
return HandleBuiltinWideningCast(node);
}
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleUnrollFor(const UnrollFor* node) override;
absl::Status HandleXlsTuple(const XlsTuple* node) override;

absl::Status HandleBuiltinDecode(const Invocation* node);
absl::Status HandleBuiltinCheckedCast(const Invocation* node);
absl::Status HandleBuiltinWideningCast(const Invocation* node);
absl::Status HandleBuiltinSend(const Invocation* node);
Expand Down
34 changes: 34 additions & 0 deletions xls/dslx/bytecode/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ absl::Status BytecodeInterpreter::EvalNextInstruction() {
XLS_RETURN_IF_ERROR(EvalCreateTuple(bytecode));
break;
}
case Bytecode::Op::kDecode: {
XLS_RETURN_IF_ERROR(EvalDecode(bytecode));
break;
}
case Bytecode::Op::kDiv: {
XLS_RETURN_IF_ERROR(EvalDiv(bytecode));
break;
Expand Down Expand Up @@ -671,6 +675,33 @@ absl::Status BytecodeInterpreter::EvalCreateTuple(const Bytecode& bytecode) {
return absl::OkStatus();
}

absl::Status BytecodeInterpreter::EvalDecode(const Bytecode& bytecode) {
if (!bytecode.data().has_value() ||
!std::holds_alternative<std::unique_ptr<ConcreteType>>(
bytecode.data().value())) {
return absl::InternalError("Decode op requires ConcreteType data.");
}

XLS_ASSIGN_OR_RETURN(InterpValue from, Pop());
if (!from.IsBits() || from.IsSigned()) {
return absl::InvalidArgumentError(absl::StrCat(
"Decode op requires UBits-type input, was: ", from.ToString()));
}

ConcreteType* to =
std::get<std::unique_ptr<ConcreteType>>(bytecode.data().value()).get();
if (!IsUBits(*to)) {
return absl::InvalidArgumentError(absl::StrCat(
"Decode op requires UBits-type output, was: ", to->ToString()));
}
BitsType* to_bits = dynamic_cast<BitsType*>(to);
XLS_ASSIGN_OR_RETURN(int64_t new_bit_count, to_bits->size().GetAsInt64());

XLS_ASSIGN_OR_RETURN(InterpValue decoded, from.Decode(new_bit_count));
stack_.Push(std::move(decoded));
return absl::OkStatus();
}

absl::Status BytecodeInterpreter::EvalDiv(const Bytecode& bytecode) {
return EvalBinop([](const InterpValue& lhs, const InterpValue& rhs) {
return lhs.FloorDiv(rhs);
Expand Down Expand Up @@ -1334,6 +1365,8 @@ absl::Status BytecodeInterpreter::RunBuiltinFn(const Bytecode& bytecode,
return RunBuiltinGate(bytecode, stack_);
case Builtin::kMap:
return RunBuiltinMap(bytecode);
case Builtin::kEncode:
return RunBuiltinEncode(bytecode, stack_);
case Builtin::kOneHot:
return RunBuiltinOneHot(bytecode, stack_);
case Builtin::kOneHotSel:
Expand Down Expand Up @@ -1374,6 +1407,7 @@ absl::Status BytecodeInterpreter::RunBuiltinFn(const Bytecode& bytecode,
case Builtin::kRecvIf:
case Builtin::kRecvNonBlocking:
case Builtin::kRecvIfNonBlocking:
case Builtin::kDecode:
case Builtin::kCheckedCast:
case Builtin::kWideningCast:
case Builtin::kSelect:
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode/bytecode_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BytecodeInterpreter {
absl::Status EvalConcat(const Bytecode& bytecode);
absl::Status EvalCreateArray(const Bytecode& bytecode);
absl::Status EvalCreateTuple(const Bytecode& bytecode);
absl::Status EvalDecode(const Bytecode& bytecode);
absl::Status EvalDiv(const Bytecode& bytecode);
absl::Status EvalDup(const Bytecode& bytecode);
absl::Status EvalEq(const Bytecode& bytecode);
Expand Down
48 changes: 48 additions & 0 deletions xls/dslx/dslx_builtins.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
#include "xls/common/logging/logging.h"
#include "xls/common/status/status_macros.h"
#include "xls/dslx/errors.h"
#include "xls/dslx/frontend/ast_node.h"
#include "xls/dslx/frontend/builtins_metadata.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/interp_value.h"
#include "xls/dslx/type_system/concrete_type.h"
#include "xls/dslx/type_system/deduce.h"
Expand Down Expand Up @@ -775,6 +777,52 @@ void PopulateSignatureToLambdaMap(
return TypeAndParametricEnv{std::make_unique<FunctionType>(
CloneToUnique(data.arg_types), data.arg_types[1]->CloneToUnique())};
};
map["<uN[M]>(uN[N]) -> uN[M]"] =
[](const SignatureData& data,
DeduceCtx* ctx) -> absl::StatusOr<TypeAndParametricEnv> {
XLS_RETURN_IF_ERROR(Checker(data.arg_types, data.name, data.span, *ctx)
.Len(1)
.IsUN(0)
.status());

if (data.arg_explicit_parametrics.size() != 1) {
return ArgCountMismatchErrorStatus(
data.span,
absl::StrFormat("Invalid number of parametrics passed to '%s', "
"expected 1, got %d",
data.name, data.arg_explicit_parametrics.size()));
}
AstNode* param_type = ToAstNode(data.arg_explicit_parametrics.at(0));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<ConcreteType> return_type,
DeduceAndResolve(param_type, ctx));
XLS_ASSIGN_OR_RETURN(return_type, UnwrapMetaType(std::move(return_type),
data.span, data.name));
if (auto* a = dynamic_cast<const BitsType*>(return_type.get());
a == nullptr || a->is_signed()) {
return TypeInferenceErrorStatus(
param_type->GetSpan().value_or(FakeSpan()), return_type.get(),
absl::StrFormat("Want return type to be unsigned bits; got %s",
return_type->ToString()));
}

return TypeAndParametricEnv{std::make_unique<FunctionType>(
CloneToUnique(data.arg_types), std::move(return_type))};
};
map["(uN[N]) -> uN[ceil(log2(N))]"] =
[](const SignatureData& data,
DeduceCtx* ctx) -> absl::StatusOr<TypeAndParametricEnv> {
XLS_RETURN_IF_ERROR(Checker(data.arg_types, data.name, data.span, *ctx)
.Len(1)
.IsUN(0)
.status());
XLS_ASSIGN_OR_RETURN(ConcreteTypeDim n,
data.arg_types[0]->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(ConcreteTypeDim log2_n, n.CeilOfLog2());
auto return_type =
std::make_unique<BitsType>(/*signed=*/false, /*size=*/log2_n);
return TypeAndParametricEnv{std::make_unique<FunctionType>(
CloneToUnique(data.arg_types), std::move(return_type))};
};
map["(uN[N], u1) -> uN[N+1]"] =
[](const SignatureData& data,
DeduceCtx* ctx) -> absl::StatusOr<TypeAndParametricEnv> {
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/frontend/builtins_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const absl::flat_hash_map<std::string, BuiltinsData>& GetParametricBuiltins() {
{"fail!", {"(u8[N], T) -> T", false}},
{"gate!", {"(u1, T) -> T", false}},
{"map", {"(T[N], (T) -> U) -> U[N]", false}},
{"decode", {"<uN[M]>(uN[N]) -> uN[M]", false}},
{"encode", {"(uN[N]) -> uN[ceil(log2(N))]", false}},
{"one_hot", {"(uN[N], u1) -> uN[N+1]", false}},
{"one_hot_sel", {"(xN[N], xN[M][N]) -> xN[M]", false}},
{"priority_sel", {"(xN[N], xN[M][N]) -> xN[M]", false}},
Expand Down
47 changes: 47 additions & 0 deletions xls/dslx/interp_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
Expand All @@ -31,8 +32,10 @@
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xls/common/logging/logging.h"
#include "xls/common/math_util.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/inline_bitmap.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/format_preference.h"
Expand Down Expand Up @@ -645,6 +648,17 @@ absl::StatusOr<InterpValue> InterpValue::ArithmeticNegate() const {
return InterpValue(tag_, bits_ops::Negate(arg));
}

absl::StatusOr<InterpValue> InterpValue::CeilOfLog2() const {
XLS_ASSIGN_OR_RETURN(Bits arg, GetBits());
if (arg.IsZero()) {
return InterpValue(tag_, UBits(0, 32));
}
// Subtract one to make sure we get the right result for exact powers of 2.
int64_t min_bit_width =
arg.bit_count() - bits_ops::Decrement(arg).CountLeadingZeros();
return InterpValue(tag_, UBits(min_bit_width, 32));
}

absl::StatusOr<Bits> InterpValue::GetBits() const {
if (std::holds_alternative<Bits>(payload_)) {
return std::get<Bits>(payload_);
Expand Down Expand Up @@ -713,6 +727,39 @@ absl::StatusOr<InterpValue> InterpValue::ZeroExt(int64_t new_bit_count) const {
return InterpValue(new_tag, bits_ops::ZeroExtend(b, new_bit_count));
}

absl::StatusOr<InterpValue> InterpValue::Decode(int64_t new_bit_count) const {
XLS_ASSIGN_OR_RETURN(Bits arg, GetBits());

absl::StatusOr<uint64_t> unsigned_index = arg.ToUint64();
if (!unsigned_index.ok() ||
*unsigned_index >
static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
// Index cannot be represented in a 64-bit signed integer - so it's telling
// us to set a bit that's definitely out of range. Return 0.
return InterpValue(InterpValueTag::kUBits, Bits(new_bit_count));
}

const int64_t index = static_cast<int64_t>(*unsigned_index);
InlineBitmap result(new_bit_count);
if (index < new_bit_count) {
result.Set(index);
}
return InterpValue(InterpValueTag::kUBits,
Bits::FromBitmap(std::move(result)));
}

absl::StatusOr<InterpValue> InterpValue::Encode() const {
XLS_ASSIGN_OR_RETURN(Bits arg, GetBits());
int64_t result = 0;
for (int64_t i = 0; i < arg.bit_count(); ++i) {
if (arg.Get(i)) {
result |= i;
}
}
return InterpValue(InterpValueTag::kUBits,
UBits(result, ::xls::CeilOfLog2(arg.bit_count())));
}

absl::StatusOr<InterpValue> InterpValue::OneHot(bool lsb_prio) const {
XLS_ASSIGN_OR_RETURN(Bits arg, GetBits());
if (lsb_prio) {
Expand Down
Loading

0 comments on commit d324bfe

Please sign in to comment.