From 8b8b9e13dc7b851c0ece4eb4f3c302e5edd53da4 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Wed, 7 Aug 2024 18:44:55 +0530 Subject: [PATCH] Introduce utility class for encoding --- velox/common/encode/Base64.cpp | 61 ++----- velox/common/encode/Base64.h | 23 +-- velox/common/encode/EncoderUtils.h | 117 ++++++++++++ velox/common/encode/tests/Base64Test.cpp | 10 -- velox/common/encode/tests/CMakeLists.txt | 2 +- .../common/encode/tests/EncoderUtilsTests.cpp | 168 ++++++++++++++++++ 6 files changed, 299 insertions(+), 82 deletions(-) create mode 100644 velox/common/encode/EncoderUtils.h create mode 100644 velox/common/encode/tests/EncoderUtilsTests.cpp diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index da4e9cdbfcfd..6eb1ac2da143 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -24,6 +24,9 @@ namespace facebook::velox::encoding { +// Encoding base to be used. +constexpr static int kBase = 64; + // Constants defining the size in bytes of binary and encoded blocks for Base64 // encoding. // Size of a binary block in bytes (3 bytes = 24 bits) @@ -87,15 +90,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -// Validate the character in charset with ReverseIndex table -constexpr bool checkForwardIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& reverseIndex) { - return (reverseIndex[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, reverseIndex) : true); -} - // Verify that for every entry in kBase64Charset, the corresponding entry // in kBase64ReverseIndexTable is correct. static_assert( @@ -114,32 +108,12 @@ static_assert( kBase64UrlReverseIndexTable), "kBase64UrlCharset has incorrect entries"); -// Searches for a character within a charset up to a certain index. -constexpr bool findCharacterInCharset( - const Base64::Charset& charset, - uint8_t idx, - const char c) { - return idx < charset.size() && - ((charset[idx] == c) || findCharacterInCharset(charset, idx + 1, c)); -} - -// Checks the consistency of a reverse index mapping for a given character -// set. -constexpr bool checkReverseIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& reverseIndex) { - return (reverseIndex[idx] == 255 - ? !findCharacterInCharset(charset, 0, static_cast(idx)) - : (charset[reverseIndex[idx]] == idx)) && - (idx > 0 ? checkReverseIndex(idx - 1, charset, reverseIndex) : true); -} - // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. static_assert( checkReverseIndex( sizeof(kBase64ReverseIndexTable) - 1, + kBase, kBase64Charset, kBase64ReverseIndexTable), "kBase64ReverseIndexTable has incorrect entries."); @@ -326,17 +300,6 @@ void Base64::decode(const char* data, size_t size, char* output) { Base64::decode(data, size, output, out_len); } -// static -uint8_t Base64::base64ReverseLookup( - char p, - const Base64::ReverseIndex& reverseIndex) { - auto curr = reverseIndex[(uint8_t)p]; - if (curr >= 0x40) { - VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); - } - return curr; -} - // static size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { @@ -409,10 +372,10 @@ size_t Base64::decodeImpl( // Each character of the 4 encode 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8 bit bytes. - uint32_t last = (base64ReverseLookup(src[0], reverseIndex) << 18) | - (base64ReverseLookup(src[1], reverseIndex) << 12) | - (base64ReverseLookup(src[2], reverseIndex) << 6) | - base64ReverseLookup(src[3], reverseIndex); + uint32_t last = (baseReverseLookup(kBase, src[0], reverseIndex) << 18) | + (baseReverseLookup(kBase, src[1], reverseIndex) << 12) | + (baseReverseLookup(kBase, src[2], reverseIndex) << 6) | + baseReverseLookup(kBase, src[3], reverseIndex); dst[0] = (last >> 16) & 0xff; dst[1] = (last >> 8) & 0xff; dst[2] = last & 0xff; @@ -421,14 +384,14 @@ size_t Base64::decodeImpl( // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(src_len >= 2); - uint32_t last = (base64ReverseLookup(src[0], reverseIndex) << 18) | - (base64ReverseLookup(src[1], reverseIndex) << 12); + uint32_t last = (baseReverseLookup(kBase, src[0], reverseIndex) << 18) | + (baseReverseLookup(kBase, src[1], reverseIndex) << 12); dst[0] = (last >> 16) & 0xff; if (src_len > 2) { - last |= base64ReverseLookup(src[2], reverseIndex) << 6; + last |= baseReverseLookup(kBase, src[2], reverseIndex) << 6; dst[1] = (last >> 8) & 0xff; if (src_len > 3) { - last |= base64ReverseLookup(src[3], reverseIndex); + last |= baseReverseLookup(kBase, src[3], reverseIndex); dst[2] = last & 0xff; } } diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 13004175379a..7972b1143d4b 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -24,6 +24,7 @@ #include #include "velox/common/base/GTestMacros.h" +#include "velox/common/encode/EncoderUtils.h" namespace facebook::velox::encoding { @@ -112,25 +113,6 @@ class Base64 { decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); private: - /// Checks if there is padding in encoded data. - static inline bool isPadded(const char* data, size_t len) { - return (len > 0 && data[len - 1] == kPadding); - } - - /// Counts the number of padding characters in encoded data. - static inline size_t numPadding(const char* src, size_t len) { - size_t numPadding{0}; - while (len > 0 && src[len - 1] == kPadding) { - numPadding++; - len--; - } - return numPadding; - } - - /// Performs a reverse lookup in the reverse index to retrieve the original - /// index of a character in the base. - static uint8_t base64ReverseLookup(char p, const ReverseIndex& reverseIndex); - /// Encodes the specified data using the provided charset. template static std::string @@ -151,9 +133,6 @@ class Base64 { char* dst, size_t dst_len, const ReverseIndex& table); - - VELOX_FRIEND_TEST(Base64Test, checksPadding); - VELOX_FRIEND_TEST(Base64Test, countsPaddingCorrectly); }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h new file mode 100644 index 000000000000..0dea302a36dc --- /dev/null +++ b/velox/common/encode/EncoderUtils.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::encoding { + +const size_t kCharsetSize = 64; +const size_t kReverseIndexSize = 256; + +/// Character set used for encoding purposes. +/// Contains specific characters that form the encoding scheme. +using Charset = std::array; + +/// Reverse lookup table for decoding purposes. +/// Maps each possible encoded character to its corresponding numeric value +/// within the encoding base. +using ReverseIndex = std::array; + +/// Padding character used in encoding. +const static char kPadding = '='; +/// Checks if there is padding in encoded data. +static inline bool isPadded(const char* data, size_t len) { + return (len > 0 && data[len - 1] == kPadding) ? true : false; +} + +/// Counts the number of padding characters in encoded data. +static inline size_t numPadding(const char* src, size_t len) { + size_t numPadding{0}; + while (len > 0 && src[len - 1] == kPadding) { + numPadding++; + len--; + } + return numPadding; +} +/// Performs a reverse lookup in the reverse index to retrieve the original +/// index of a character in the base. +inline uint8_t +baseReverseLookup(int base, char p, const ReverseIndex& reverseIndex) { + auto curr = reverseIndex[(uint8_t)p]; + if (curr >= base) { + VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); + } + return curr; +} + +// Validate the character in charset with ReverseIndex table +static constexpr bool checkForwardIndex( + uint8_t idx, + const Charset& charset, + const ReverseIndex& reverseIndex) { + for (uint8_t i = 0; i <= idx; ++i) { + if (!(reverseIndex[static_cast(charset[i])] == i)) { + return false; + } + } + return true; +} + +/// Searches for a character within a charset up to a certain index. +constexpr bool findCharacterInCharSet( + const Charset& charset, + int base, + uint8_t idx, + const char c) { + for (; idx < base; ++idx) { + if (charset[idx] == c) { + return true; + } + } + return false; +} + +/// Checks the consistency of a reverse index mapping for a given character +/// set. +static constexpr bool checkReverseIndex( + uint8_t idx, + int base, + const Charset& charset, + const ReverseIndex& reverseIndex) { + for (uint8_t currentIdx = idx; currentIdx != static_cast(-1); + --currentIdx) { + if (reverseIndex[currentIdx] == 255) { + if (findCharacterInCharSet( + charset, base, 0, static_cast(currentIdx))) { + return false; + } + } else { + if (!(charset[reverseIndex[currentIdx]] == currentIdx)) { + return false; + } + } + } + return true; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 9cbbbad47124..e62f3f99e8ad 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -89,14 +89,4 @@ TEST_F(Base64Test, calculateDecodedSizeProperSize) { EXPECT_EQ(14, encoded_size); } -TEST_F(Base64Test, checksPadding) { - EXPECT_TRUE(Base64::isPadded("ABC=", 4)); - EXPECT_FALSE(Base64::isPadded("ABC", 3)); -} - -TEST_F(Base64Test, countsPaddingCorrectly) { - EXPECT_EQ(0, Base64::numPadding("ABC", 3)); - EXPECT_EQ(1, Base64::numPadding("ABC=", 4)); - EXPECT_EQ(2, Base64::numPadding("AB==", 4)); -} } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 90c9733ecf22..2e1e79ea222e 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_common_encode_test Base64Test.cpp) +add_executable(velox_common_encode_test Base64Test.cpp EncoderUtilsTests.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test diff --git a/velox/common/encode/tests/EncoderUtilsTests.cpp b/velox/common/encode/tests/EncoderUtilsTests.cpp new file mode 100644 index 000000000000..f7e02bc06912 --- /dev/null +++ b/velox/common/encode/tests/EncoderUtilsTests.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/encode/EncoderUtils.h" + +namespace facebook::velox::encoding { +class EncoderUtilsTest : public ::testing::Test {}; + +TEST_F(EncoderUtilsTest, isPadded) { + EXPECT_TRUE(isPadded("ABC=", 4)); + EXPECT_FALSE(isPadded("ABC", 3)); +} + +TEST_F(EncoderUtilsTest, numPadding) { + EXPECT_EQ(0, numPadding("ABC", 3)); + EXPECT_EQ(1, numPadding("ABC=", 4)); + EXPECT_EQ(2, numPadding("AB==", 4)); +} + +constexpr Charset testCharset = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; + +constexpr ReverseIndex testReverseIndex = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, + 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + +constexpr const Charset testWrongCharset = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}; + +constexpr ReverseIndex testWrongReverseIndex = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, + 62, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 63, 255, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + +int base64 = 64; + +TEST_F(EncoderUtilsTest, baseReverseLookup) { + EXPECT_NO_THROW(baseReverseLookup(base64, 'A', testReverseIndex)); + EXPECT_THROW( + baseReverseLookup(base64, '=', testReverseIndex), VeloxUserError); +} + +TEST_F(EncoderUtilsTest, checkForwardIndex) { + EXPECT_TRUE(checkForwardIndex(63, testCharset, testReverseIndex)); +} + +TEST_F(EncoderUtilsTest, checkReverseIndex) { + EXPECT_TRUE(checkReverseIndex(255, base64, testCharset, testReverseIndex)); +} + +TEST_F(EncoderUtilsTest, HandlesLookupAndExceptions) { + EXPECT_NO_THROW(baseReverseLookup(base64, 'A', testReverseIndex)); + EXPECT_THROW( + baseReverseLookup(base64, '=', testReverseIndex), VeloxUserError); +} + +TEST_F(EncoderUtilsTest, ValidatesCharsetWithReverseIndex) { + EXPECT_TRUE(checkForwardIndex(63, testCharset, testReverseIndex)); +} + +TEST_F(EncoderUtilsTest, ValidatesReverseIndexWithCharset) { + EXPECT_TRUE(checkReverseIndex( + sizeof(testCharset) - 1, base64, testCharset, testReverseIndex)); + EXPECT_FALSE(checkReverseIndex( + sizeof(testWrongCharset) - 1, + base64, + testWrongCharset, + testWrongReverseIndex)); +} + +TEST_F(EncoderUtilsTest, CharacterDoesNotExist) { + Charset charset = {'A', 'B', 'C', 'D', 'E', 'F'}; + EXPECT_TRUE(findCharacterInCharSet(charset, charset.size(), 1, 'C')); + EXPECT_FALSE(findCharacterInCharSet(charset, charset.size(), 1, 'A')); + EXPECT_TRUE(findCharacterInCharSet(charset, charset.size(), 0, 'A')); +} + +TEST_F(EncoderUtilsTest, EmptyCharset) { + Charset emptyCharset; + EXPECT_FALSE( + findCharacterInCharSet(emptyCharset, base64, emptyCharset.size(), 'A')); +} + +TEST_F(EncoderUtilsTest, CheckForwardIndex_PartialValid) { + // Test partial index range + EXPECT_TRUE(checkForwardIndex(10, testCharset, testReverseIndex)); + EXPECT_TRUE(checkForwardIndex(20, testCharset, testReverseIndex)); + EXPECT_TRUE(checkForwardIndex(30, testCharset, testReverseIndex)); +} + +TEST_F(EncoderUtilsTest, CheckForwardIndex_CorruptedReverseIndex) { + // Corrupting reverse index + ReverseIndex corruptedReverseIndex = testReverseIndex; + corruptedReverseIndex['A'] = 255; + EXPECT_FALSE(checkForwardIndex(63, testCharset, corruptedReverseIndex)); +} + +TEST_F(EncoderUtilsTest, CheckReverseIndex_PartialValid) { + // Test partial index range + EXPECT_TRUE(checkReverseIndex(10, base64, testCharset, testReverseIndex)); + EXPECT_TRUE(checkReverseIndex(20, base64, testCharset, testReverseIndex)); + EXPECT_TRUE(checkReverseIndex(30, base64, testCharset, testReverseIndex)); +} + +TEST_F(EncoderUtilsTest, CheckReverseIndex_CorruptedCharset) { + // Corrupting charset + Charset corruptedCharset = testCharset; + corruptedCharset[10] = '@'; + EXPECT_FALSE( + checkReverseIndex(64, base64, corruptedCharset, testReverseIndex)); + EXPECT_TRUE(checkReverseIndex(255, base64, testCharset, testReverseIndex)); +} + +} // namespace facebook::velox::encoding