From 3e7a09288321d41cc278bf487ff38980c736c3ed Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Thu, 3 Oct 2024 12:21:23 +0530 Subject: [PATCH] Update decode as non-throwing API --- velox/common/encode/Base64.cpp | 121 +++++++++++--------- velox/common/encode/Base64.h | 14 ++- velox/common/encode/CMakeLists.txt | 5 +- velox/common/encode/tests/Base64Test.cpp | 39 ++++--- velox/common/encode/tests/CMakeLists.txt | 7 +- velox/functions/prestosql/BinaryFunctions.h | 29 +++-- 6 files changed, 120 insertions(+), 95 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 8c4e75d4a247..4b32c3f9c454 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -18,9 +18,7 @@ #include #include #include -#include - -#include "velox/common/base/Exceptions.h" +#include namespace facebook::velox::encoding { @@ -157,15 +155,16 @@ static_assert( // "kBase64UrlReverseIndexTable has incorrect entries."); // Implementation of Base64 encoding and decoding functions. +// static template -/* static */ std::string Base64::encodeImpl( +std::string Base64::encodeImpl( const T& input, - const Base64::Charset& charset, + const Charset& charset, bool includePadding) { - size_t encodedSize = calculateEncodedSize(input.size(), includePadding); + const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)}; std::string encodedResult; encodedResult.resize(encodedSize); - encodeImpl(input, charset, includePadding, encodedResult.data()); + (void)encodeImpl(input, charset, includePadding, encodedResult.data()); return encodedResult; } @@ -185,26 +184,31 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { } // static -void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( +Status Base64::encode(const char* input, size_t inputSize, char* output) { + return encodeImpl( folly::StringPiece(input, inputSize), kBase64Charset, true, output); } // static -void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); +Status +Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) { + return encodeImpl( + folly::StringPiece(input, inputSize), + kBase64UrlCharset, + true, + outputBuffer); } +// static template -/* static */ void Base64::encodeImpl( +Status Base64::encodeImpl( const T& input, const Base64::Charset& charset, bool includePadding, char* outputBuffer) { auto inputSize = input.size(); if (inputSize == 0) { - return; + return Status::OK(); } auto outputPointer = outputBuffer; @@ -213,9 +217,9 @@ template // For each group of 3 bytes (24 bits) in the input, split that into // 4 groups of 6 bits and encode that using the supplied charset lookup for (; inputSize > 2; inputSize -= 3) { - uint32_t inputBlock = uint8_t(*inputIterator++) << 16; - inputBlock |= uint8_t(*inputIterator++) << 8; - inputBlock |= uint8_t(*inputIterator++); + uint32_t inputBlock = static_cast(*inputIterator++) << 16; + inputBlock |= static_cast(*inputIterator++) << 8; + inputBlock |= static_cast(*inputIterator++); *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; @@ -227,10 +231,10 @@ template // We have either 1 or 2 input bytes left. Encode this similar to the // above (assuming 0 for all other bytes). Optionally append the '=' // character if it is requested. - uint32_t inputBlock = uint8_t(*inputIterator++) << 16; + uint32_t inputBlock = static_cast(*inputIterator++) << 16; *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; if (inputSize > 1) { - inputBlock |= uint8_t(*inputIterator) << 8; + inputBlock |= static_cast(*inputIterator) << 8; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 6) & 0x3f]; if (includePadding) { @@ -244,6 +248,7 @@ template } } } + return Status::OK(); } // static @@ -320,23 +325,28 @@ void Base64::decode( const std::pair& payload, std::string& decodedOutput) { size_t inputSize = payload.second; - decodedOutput.resize(calculateDecodedSize(payload.first, inputSize)); - decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); + (void)decode( + payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); } // static void Base64::decode(const char* input, size_t size, char* output) { size_t expectedOutputSize = size / 4 * 3; - Base64::decode(input, size, output, expectedOutputSize); + (void)Base64::decode(input, size, output, expectedOutputSize); } // static uint8_t Base64::base64ReverseLookup( char encodedChar, - const Base64::ReverseIndex& reverseIndex) { - auto reverseLookupValue = reverseIndex[(uint8_t)encodedChar]; + const Base64::ReverseIndex& reverseIndex, + Status& status) { + auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; if (reverseLookupValue >= 0x40) { - VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); + status = Status::UserError( + "decode() - invalid input string: invalid characters"); } return reverseLookupValue; } @@ -352,9 +362,12 @@ Status Base64::decode( } // static -size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { +Status Base64::calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize) { if (inputSize == 0) { - return 0; + return Status::OK(); } // Check if the input string is padded @@ -362,37 +375,37 @@ size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { // If padded, ensure that the string length is a multiple of the encoded // block size if (inputSize % kEncodedBlockByteSize != 0) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length is not a multiple of 4."); } - auto decodedSize = - (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; + decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; auto paddingCount = numPadding(input, inputSize); inputSize -= paddingCount; // Adjust the needed size by deducting the bytes corresponding to the // padding from the calculated size. - return decodedSize - + decodedSize -= ((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / kEncodedBlockByteSize; + return Status::OK(); } // If not padded, Calculate extra bytes, if any auto extraBytes = inputSize % kEncodedBlockByteSize; - auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; + decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } - return decodedSize; + return Status::OK(); } // static @@ -406,43 +419,47 @@ Status Base64::decodeImpl( return Status::OK(); } - auto decodedSize = calculateDecodedSize(input, inputSize); + size_t decodedSize; + (void)calculateDecodedSize(input, inputSize, decodedSize); if (outputSize < decodedSize) { return Status::UserError( "Base64::decode() - invalid output string: output string is too small."); } + Status lookupStatus; // Handle full groups of 4 characters for (; inputSize > 4; inputSize -= 4, input += 4, outputBuffer += 3) { // Each character of the 4 encodes 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8-bit bytes. uint32_t decodedBlock = - (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12) | - (base64ReverseLookup(input[2], reverseIndex) << 6) | - base64ReverseLookup(input[3], reverseIndex); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; - outputBuffer[2] = decodedBlock & 0xff; + (base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) | + (base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12) | + (base64ReverseLookup(input[2], reverseIndex, lookupStatus) << 6) | + base64ReverseLookup(input[3], reverseIndex, lookupStatus); + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); + outputBuffer[2] = static_cast(decodedBlock & 0xff); } // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(inputSize >= 2); - uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; + uint32_t decodedBlock = + (base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) | + (base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12); + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); if (inputSize > 2) { - decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; + decodedBlock |= base64ReverseLookup(input[2], reverseIndex, lookupStatus) + << 6; + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); if (inputSize > 3) { - decodedBlock |= base64ReverseLookup(input[3], reverseIndex); - outputBuffer[2] = decodedBlock & 0xff; + decodedBlock |= base64ReverseLookup(input[3], reverseIndex, lookupStatus); + outputBuffer[2] = static_cast(decodedBlock & 0xff); } } - return Status::OK(); + return (lookupStatus != Status::OK()) ? lookupStatus : Status::OK(); } // static @@ -483,7 +500,9 @@ void Base64::decodeUrl( const std::pair& payload, std::string& decodedOutput) { size_t inputSize = payload.second; - decodedOutput.resize(calculateDecodedSize(payload.first, inputSize)); + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); (void)Base64::decodeImpl( payload.first, payload.second, diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 841cee17fa76..e8dd49df1985 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -44,13 +44,13 @@ class Base64 { static std::string encode(const char* input, size_t inputSize); static std::string encode(folly::StringPiece text); static std::string encode(const folly::IOBuf* inputBuffer); - static void encode(const char* input, size_t inputSize, char* outputBuffer); + static Status encode(const char* input, size_t inputSize, char* outputBuffer); /// Encodes the input data using Base64 URL encoding. static std::string encodeUrl(const char* input, size_t inputSize); static std::string encodeUrl(folly::StringPiece text); static std::string encodeUrl(const folly::IOBuf* inputBuffer); - static void + static Status encodeUrl(const char* input, size_t inputSize, char* outputBuffer); // Decoding Functions @@ -83,7 +83,10 @@ class Base64 { /// Calculates the decoded size based on encoded input and adjusts the input /// size for padding. - static size_t calculateDecodedSize(const char* input, size_t& inputSize); + static Status calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize); private: // Checks if the input Base64 string is padded. @@ -105,14 +108,15 @@ class Base64 { // character. static uint8_t base64ReverseLookup( char encodedChar, - const ReverseIndex& reverseIndex); + const ReverseIndex& reverseIndex, + Status& status); template static std::string encodeImpl(const T& input, const Charset& charset, bool includePadding); template - static void encodeImpl( + static Status encodeImpl( const T& input, const Charset& charset, bool includePadding, diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index f67ba8eab577..501c690c476b 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -17,7 +17,4 @@ if(${VELOX_BUILD_TESTING}) endif() velox_add_library(velox_encode Base64.cpp) -velox_link_libraries( - velox_encode - PUBLIC Folly::folly - PRIVATE velox_status) +velox_link_libraries(velox_encode PUBLIC Folly::folly) diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 9cbbbad47124..ecfbf20a09f2 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -50,43 +50,48 @@ TEST_F(Base64Test, fromBase64) { TEST_F(Base64Test, calculateDecodedSizeProperSize) { size_t encoded_size{0}; + size_t decoded_size{0}; encoded_size = 20; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 18; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 21; - VELOX_ASSERT_THROW( - Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); - - encoded_size = 32; EXPECT_EQ( - 23, + Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4."), Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size)); + "SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size)); + + encoded_size = 32; + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 31; - EXPECT_EQ( - 23, - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size)); + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 16; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); encoded_size = 14; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); } TEST_F(Base64Test, checksPadding) { diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index f3bc8b6f0612..63f718c24745 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -17,9 +17,4 @@ add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test PUBLIC Folly::folly - PRIVATE - velox_encode - velox_exception - velox_status - GTest::gtest - GTest::gtest_main) + PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main) diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 23926836e83c..d83d34aefb23 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -278,11 +278,10 @@ template struct ToBase64Function { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encode(input.data(), input.size(), result.data()); + return encoding::Base64::encode(input.data(), input.size(), result.data()); } }; @@ -295,8 +294,11 @@ struct FromBase64Function { template FOLLY_ALWAYS_INLINE Status call(out_type& result, const T& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize(input.data(), inputSize,decodedSize); + if(status != Status::OK()) + return status; + result.resize(decodedSize); return encoding::Base64::decode( input.data(), inputSize, result.data(), result.size()); } @@ -308,8 +310,11 @@ struct FromBase64UrlFunction { FOLLY_ALWAYS_INLINE Status call(out_type& result, const arg_type& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize(input.data(), inputSize,decodedSize); + if(status != Status::OK()) + return status; + result.resize(decodedSize); return encoding::Base64::decodeUrl( input.data(), inputSize, result.data(), result.size()); } @@ -319,11 +324,11 @@ template struct ToBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encodeUrl(input.data(), input.size(), result.data()); + return encoding::Base64::encodeUrl( + input.data(), input.size(), result.data()); } };