Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions runtime-light/stdlib/rpc/rpc-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ inline bool f$store_string(const string& v) noexcept {
}

inline bool f$store_string2(const string& v) noexcept {
tl::string{.value = {v.c_str(), v.size()}}.store2(RpcServerInstanceState::get().tl_storer);
tl2::string{.value = {v.c_str(), v.size()}}.store(RpcServerInstanceState::get().tl_storer);
return true;
}

Expand Down Expand Up @@ -175,7 +175,7 @@ inline string f$fetch_string() noexcept {
}

inline string f$fetch_string2() noexcept {
if (tl::string val{}; val.fetch2(RpcServerInstanceState::get().tl_fetcher)) [[likely]] {
if (tl2::string val{}; val.fetch(RpcServerInstanceState::get().tl_fetcher)) [[likely]] {
return {val.value.data(), static_cast<string::size_type>(val.value.size())};
}
THROW_EXCEPTION(kphp::rpc::exception::cant_fetch_string::make());
Expand Down
24 changes: 12 additions & 12 deletions runtime-light/tl/tl-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ bool K2InvokeJobWorker::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
tl::mask flags{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_JOB_WORKER_MAGIC)};
ok &= flags.fetch(tlf);
ok &= image_id.fetch(tlf);
ok &= job_id.fetch(tlf);
ok &= timeout_ns.fetch(tlf);
ok &= body.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && image_id.fetch(tlf);
ok = ok && job_id.fetch(tlf);
ok = ok && timeout_ns.fetch(tlf);
ok = ok && body.fetch(tlf);
ignore_answer = static_cast<bool>(flags.value & IGNORE_ANSWER_FLAG);
return ok;
}
Expand All @@ -42,14 +42,14 @@ bool K2InvokeHttp::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
tl::mask flags{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_HTTP_MAGIC)};
ok &= flags.fetch(tlf);
ok &= connection.fetch(tlf);
ok &= version.fetch(tlf);
ok &= method.fetch(tlf);
ok &= uri.fetch(tlf);
ok &= headers.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && connection.fetch(tlf);
ok = ok && version.fetch(tlf);
ok = ok && method.fetch(tlf);
ok = ok && uri.fetch(tlf);
ok = ok && headers.fetch(tlf);
const auto opt_body{tlf.fetch_bytes(tlf.remaining())};
ok &= opt_body.has_value();
ok = ok && opt_body.has_value();

body = opt_body.value_or(std::span<const std::byte>{});

Expand Down
10 changes: 5 additions & 5 deletions runtime-light/tl/tl-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ class K2InvokeRpc final {
bool fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_RPC_MAGIC)};
ok &= flags.fetch(tlf);
ok &= query_id.fetch(tlf);
ok &= net_pid.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && query_id.fetch(tlf);
ok = ok && net_pid.fetch(tlf);
if (static_cast<bool>(flags.value & ACTOR_ID_FLAG)) {
ok &= opt_actor_id.emplace().fetch(tlf);
ok = ok && opt_actor_id.emplace().fetch(tlf);
}
if (static_cast<bool>(flags.value & EXTRA_FLAG)) {
ok &= opt_extra.emplace().fetch(tlf);
ok = ok && opt_extra.emplace().fetch(tlf);
}
const auto opt_query{tlf.fetch_bytes(tlf.remaining())};
query = opt_query.value_or(std::span<const std::byte>{});
Expand Down
214 changes: 97 additions & 117 deletions runtime-light/tl/tl-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <span>
#include <utility>

#include "runtime-light/stdlib/diagnostics/logs.h"
Expand All @@ -25,22 +26,19 @@ bool string::fetch(tl::fetcher& tlf) noexcept {
uint8_t size_len{};
uint64_t string_len{};
switch (first_byte) {
case LARGE_STRING_MAGIC: {
if (tlf.remaining() < LARGE_STRING_SIZE_LEN) [[unlikely]] {
case HUGE_STRING_MAGIC: {
if (tlf.remaining() < HUGE_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
size_len = LARGE_STRING_SIZE_LEN + 1;
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
const auto fourth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 24};
const auto fifth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 32};
const auto sixth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 40};
const auto seventh{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 48};
string_len = first | second | third | fourth | fifth | sixth | seventh;
size_len = HUGE_STRING_SIZE_LEN + 1;
auto len_bytes{*tlf.fetch_bytes(HUGE_STRING_SIZE_LEN)};
string_len = static_cast<uint64_t>(len_bytes[0]) | (static_cast<uint64_t>(len_bytes[1]) << 8) | (static_cast<uint64_t>(len_bytes[2]) << 16) |
(static_cast<uint64_t>(len_bytes[3]) << 24) | (static_cast<uint64_t>(len_bytes[4]) << 32) | (static_cast<uint64_t>(len_bytes[5]) << 40) |
(static_cast<uint64_t>(len_bytes[6]) << 48);

if (string_len <= MEDIUM_STRING_MAX_LEN) [[unlikely]] {
kphp::log::warning("large string's length is less than (1 << 24) - 1 (length = {})", string_len);
return false;
}
break;
}
Expand All @@ -49,18 +47,17 @@ bool string::fetch(tl::fetcher& tlf) noexcept {
return false;
}
size_len = MEDIUM_STRING_SIZE_LEN + 1;
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
string_len = first | second | third;
auto len_bytes{*tlf.fetch_bytes(MEDIUM_STRING_SIZE_LEN)};
string_len = static_cast<uint64_t>(len_bytes[0]) | (static_cast<uint64_t>(len_bytes[1]) << 8) | (static_cast<uint64_t>(len_bytes[2]) << 16);

if (string_len <= SMALL_STRING_MAX_LEN) [[unlikely]] {
if (string_len <= TINY_STRING_MAX_LEN) [[unlikely]] {
kphp::log::warning("long string's length is less than 254 (length = {})", string_len);
return false;
}
break;
}
default: {
size_len = SMALL_STRING_SIZE_LEN;
size_len = TINY_STRING_SIZE_LEN;
string_len = static_cast<uint64_t>(first_byte);
break;
}
Expand All @@ -83,20 +80,27 @@ void string::store(tl::storer& tls) const noexcept {
const char* str_buf{value.data()};
size_t str_len{value.size()};
uint8_t size_len{};
if (str_len <= SMALL_STRING_MAX_LEN) {
size_len = SMALL_STRING_SIZE_LEN;
if (str_len <= TINY_STRING_MAX_LEN) {
size_len = TINY_STRING_SIZE_LEN;
tls.store_trivial<uint8_t>(str_len);
} else if (str_len <= MEDIUM_STRING_MAX_LEN) {
size_len = MEDIUM_STRING_SIZE_LEN + 1;
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 16) & 0xff);
std::array<std::byte, MEDIUM_STRING_SIZE_LEN> len_bytes{static_cast<std::byte>(str_len & 0xff), static_cast<std::byte>((str_len >> 8) & 0xff),
static_cast<std::byte>((str_len >> 16) & 0xff)};
tls.store_bytes(len_bytes);
} else if (str_len <= HUGE_STRING_MAX_LEN) {
size_len = HUGE_STRING_SIZE_LEN + 1;
tls.store_trivial<uint8_t>(HUGE_STRING_MAGIC);
std::array<std::byte, HUGE_STRING_SIZE_LEN> len_bytes{static_cast<std::byte>(str_len & 0xff), static_cast<std::byte>((str_len >> 8) & 0xff),
static_cast<std::byte>((str_len >> 16) & 0xff), static_cast<std::byte>((str_len >> 24) & 0xff),
static_cast<std::byte>((str_len >> 32) & 0xff), static_cast<std::byte>((str_len >> 40) & 0xff),
static_cast<std::byte>((str_len >> 48) & 0xff)};
tls.store_bytes(len_bytes);
} else {
kphp::log::warning("large strings aren't supported");
size_len = SMALL_STRING_SIZE_LEN;
kphp::log::warning("string length exceeds maximum allowed length: max allowed -> {}, actual -> {}", HUGE_STRING_MAX_LEN, str_len);
size_len = 0;
str_len = 0;
tls.store_trivial<uint8_t>(str_len);
}
tls.store_bytes({reinterpret_cast<const std::byte*>(str_buf), str_len});

Expand All @@ -108,90 +112,6 @@ void string::store(tl::storer& tls) const noexcept {
tls.store_bytes({reinterpret_cast<const std::byte*>(padding_array.data()), padding});
}

bool string::fetch2_len(tl::fetcher& tlf, uint64_t& string_len) noexcept {
uint8_t first_byte{};
if (const auto opt_first_byte{tlf.fetch_trivial<uint8_t>()}; opt_first_byte) [[likely]] {
first_byte = *opt_first_byte;
} else {
return false;
}

switch (first_byte) {
case LARGE_STRING_MAGIC: {
if (tlf.remaining() < 8) [[unlikely]] {
return false;
}
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
const auto fourth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 24};
const auto fifth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 32};
const auto sixth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 40};
const auto seventh{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 48};
const auto eighth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 56};
string_len = first | second | third | fourth | fifth | sixth | seventh | eighth;
// we allow non-canonical length to speed up some rare implementations
return true;
}
case MEDIUM_STRING_MAGIC: {
if (tlf.remaining() < 2) [[unlikely]] {
return false;
}
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
string_len = MEDIUM_STRING_MAGIC + (first | second);
return true;
}
default: {
string_len = static_cast<uint64_t>(first_byte);
return true;
}
}
}

bool string::fetch2(tl::fetcher& tlf) noexcept {
uint64_t string_len{};
if (!string::fetch2_len(tlf, string_len)) {
return false;
}
if (tlf.remaining() < string_len) [[unlikely]] {
return false;
}

value = {reinterpret_cast<const char*>(std::next(tlf.view().data(), tlf.pos())), static_cast<size_t>(string_len)};
tlf.adjust(string_len);
return true;
}

void string::store2_len(tl::storer& tls, uint64_t str_len) noexcept {
if (str_len < MEDIUM_STRING_MAGIC) {
tls.store_trivial<uint8_t>(str_len);
return;
}
if (str_len < MEDIUM_STRING_MAGIC + static_cast<uint64_t>(1 << 16)) {
str_len -= MEDIUM_STRING_MAGIC;
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
return;
}
tls.store_trivial<uint8_t>(LARGE_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 16) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 24) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 32) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 40) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 48) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 56) & 0xff);
}

void string::store2(tl::storer& tls) const noexcept {
uint64_t str_len = value.size();
string::store2_len(tls, str_len);
tls.store_bytes({reinterpret_cast<const std::byte*>(value.data()), str_len});
}

bool CertInfoItem::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
if (!magic.fetch(tlf)) [[unlikely]] {
Expand Down Expand Up @@ -230,28 +150,28 @@ bool CertInfoItem::fetch(tl::fetcher& tlf) noexcept {
bool rpcInvokeReqExtra::fetch(tl::fetcher& tlf) noexcept {
bool ok{flags.fetch(tlf)};
if (ok && static_cast<bool>(flags.value & WAIT_BINLOG_POS_FLAG)) {
ok &= opt_wait_binlog_pos.emplace().fetch(tlf);
ok = ok && opt_wait_binlog_pos.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & STRING_FORWARD_KEYS_FLAG)) {
ok &= opt_string_forward_keys.emplace().fetch(tlf);
ok = ok && opt_string_forward_keys.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & INT_FORWARD_KEYS_FLAG)) {
ok &= opt_int_forward_keys.emplace().fetch(tlf);
ok = ok && opt_int_forward_keys.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & STRING_FORWARD_FLAG)) {
ok &= opt_string_forward.emplace().fetch(tlf);
ok = ok && opt_string_forward.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & INT_FORWARD_FLAG)) {
ok &= opt_int_forward.emplace().fetch(tlf);
ok = ok && opt_int_forward.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & CUSTOM_TIMEOUT_MS_FLAG)) {
ok &= opt_custom_timeout_ms.emplace().fetch(tlf);
ok = ok && opt_custom_timeout_ms.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & SUPPORTED_COMPRESSION_VERSION_FLAG)) {
ok &= opt_supported_compression_version.emplace().fetch(tlf);
ok = ok && opt_supported_compression_version.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & RANDOM_DELAY_FLAG)) {
ok &= opt_random_delay.emplace().fetch(tlf);
ok = ok && opt_random_delay.emplace().fetch(tlf);
}

return_binlog_pos = static_cast<bool>(flags.value & RETURN_BINLOG_POS_FLAG);
Expand Down Expand Up @@ -327,3 +247,63 @@ size_t rpcReqResultExtra::footprint() const noexcept {
}

} // namespace tl

namespace tl2 {

bool string::fetch(tl::fetcher& tlf) noexcept {
uint8_t first_byte{};
if (const auto opt_first_byte{tlf.fetch_trivial<uint8_t>()}; opt_first_byte) [[likely]] {
first_byte = *opt_first_byte;
} else {
return false;
}

uint64_t string_len{};
switch (first_byte) {
case HUGE_STRING_MAGIC: {
if (tlf.remaining() < HUGE_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
// we allow non-canonical length to speed up some rare implementations
string_len = *tlf.fetch_trivial<uint64_t>();
break;
}
case MEDIUM_STRING_MAGIC: {
if (tlf.remaining() < MEDIUM_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
string_len = MEDIUM_STRING_MAGIC + *tlf.fetch_trivial<uint16_t>();
break;
}
default: {
string_len = static_cast<uint64_t>(first_byte);
break;
}
}

if (auto remaining{tlf.remaining()}; remaining < string_len) [[unlikely]] {
kphp::log::warning("not enough space in buffer to fetch string: required {} bytes, remain {} bytes", string_len, remaining);
return false;
}

value = {reinterpret_cast<const char*>(std::next(tlf.view().data(), tlf.pos())), static_cast<size_t>(string_len)};
tlf.adjust(string_len);
return true;
}

void string::store(tl::storer& tls) const noexcept {
const size_t str_len{value.size()};

if (str_len <= TINY_STRING_MAX_LEN) {
tls.store_trivial<uint8_t>(str_len);
} else if (str_len <= MEDIUM_STRING_MAX_LEN) {
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint16_t>(str_len - MEDIUM_STRING_MAGIC);
} else {
tls.store_trivial<uint8_t>(HUGE_STRING_MAGIC);
tls.store_trivial<uint64_t>(str_len);
}
tls.store_bytes({reinterpret_cast<const std::byte*>(value.data()), str_len});
}

} // namespace tl2
Loading
Loading