Skip to content

Commit

Permalink
Update decode as non-throwing API
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Oct 3, 2024
1 parent b5402b4 commit 3e7a092
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 95 deletions.
121 changes: 70 additions & 51 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
#include <folly/Portability.h>
#include <folly/container/Foreach.h>
#include <folly/io/Cursor.h>
#include <stdint.h>

#include "velox/common/base/Exceptions.h"
#include <cstdint>

namespace facebook::velox::encoding {

Expand Down Expand Up @@ -157,15 +155,16 @@ static_assert(
// "kBase64UrlReverseIndexTable has incorrect entries.");

// Implementation of Base64 encoding and decoding functions.
// static
template <class T>
/* static */ std::string Base64::encodeImpl(
std::string Base64::encodeImpl(
const T& input,
const Base64::Charset& charset,
const Charset& charset,
bool includePadding) {
size_t encodedSize = calculateEncodedSize(input.size(), includePadding);
const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)};
std::string encodedResult;
encodedResult.resize(encodedSize);
encodeImpl(input, charset, includePadding, encodedResult.data());
(void)encodeImpl(input, charset, includePadding, encodedResult.data());
return encodedResult;
}

Expand All @@ -185,26 +184,31 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) {
}

// static
void Base64::encode(const char* input, size_t inputSize, char* output) {
encodeImpl(
Status Base64::encode(const char* input, size_t inputSize, char* output) {
return encodeImpl(
folly::StringPiece(input, inputSize), kBase64Charset, true, output);
}

// static
void Base64::encodeUrl(const char* input, size_t inputSize, char* output) {
encodeImpl(
folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output);
Status
Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) {
return encodeImpl(
folly::StringPiece(input, inputSize),
kBase64UrlCharset,
true,
outputBuffer);
}

// static
template <class T>
/* static */ void Base64::encodeImpl(
Status Base64::encodeImpl(
const T& input,
const Base64::Charset& charset,
bool includePadding,
char* outputBuffer) {
auto inputSize = input.size();
if (inputSize == 0) {
return;
return Status::OK();
}

auto outputPointer = outputBuffer;
Expand All @@ -213,9 +217,9 @@ template <class T>
// For each group of 3 bytes (24 bits) in the input, split it into
// 4 groups of 6 bits and encode each using the supplied charset lookup.
for (; inputSize > 2; inputSize -= 3) {
uint32_t inputBlock = uint8_t(*inputIterator++) << 16;
inputBlock |= uint8_t(*inputIterator++) << 8;
inputBlock |= uint8_t(*inputIterator++);
uint32_t inputBlock = static_cast<uint8_t>(*inputIterator++) << 16;
inputBlock |= static_cast<uint8_t>(*inputIterator++) << 8;
inputBlock |= static_cast<uint8_t>(*inputIterator++);

*outputPointer++ = charset[(inputBlock >> 18) & 0x3f];
*outputPointer++ = charset[(inputBlock >> 12) & 0x3f];
Expand All @@ -227,10 +231,10 @@ template <class T>
// We have either 1 or 2 input bytes left. Encode this similar to the
// above (assuming 0 for all other bytes). Optionally append the '='
// character if it is requested.
uint32_t inputBlock = uint8_t(*inputIterator++) << 16;
uint32_t inputBlock = static_cast<uint8_t>(*inputIterator++) << 16;
*outputPointer++ = charset[(inputBlock >> 18) & 0x3f];
if (inputSize > 1) {
inputBlock |= uint8_t(*inputIterator) << 8;
inputBlock |= static_cast<uint8_t>(*inputIterator) << 8;
*outputPointer++ = charset[(inputBlock >> 12) & 0x3f];
*outputPointer++ = charset[(inputBlock >> 6) & 0x3f];
if (includePadding) {
Expand All @@ -244,6 +248,7 @@ template <class T>
}
}
}
return Status::OK();
}

// static
Expand Down Expand Up @@ -320,23 +325,28 @@ void Base64::decode(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t inputSize = payload.second;
decodedOutput.resize(calculateDecodedSize(payload.first, inputSize));
decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
size_t decodedSize;
(void)calculateDecodedSize(payload.first, inputSize, decodedSize);
decodedOutput.resize(decodedSize);
(void)decode(
payload.first, inputSize, decodedOutput.data(), decodedOutput.size());
}

// static
void Base64::decode(const char* input, size_t size, char* output) {
size_t expectedOutputSize = size / 4 * 3;
Base64::decode(input, size, output, expectedOutputSize);
(void)Base64::decode(input, size, output, expectedOutputSize);
}

// static
uint8_t Base64::base64ReverseLookup(
char encodedChar,
const Base64::ReverseIndex& reverseIndex) {
auto reverseLookupValue = reverseIndex[(uint8_t)encodedChar];
const Base64::ReverseIndex& reverseIndex,
Status& status) {
auto reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
if (reverseLookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
status = Status::UserError(
"decode() - invalid input string: invalid characters");
}
return reverseLookupValue;
}
Expand All @@ -352,47 +362,50 @@ Status Base64::decode(
}

// static
size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) {
Status Base64::calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize) {
if (inputSize == 0) {
return 0;
return Status::OK();
}

// Check if the input string is padded
if (isPadded(input, inputSize)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (inputSize % kEncodedBlockByteSize != 0) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length is not a multiple of 4.");
}

auto decodedSize =
(inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
auto paddingCount = numPadding(input, inputSize);
inputSize -= paddingCount;

// Adjust the needed size by deducting the bytes corresponding to the
// padding from the calculated size.
return decodedSize -
decodedSize -=
((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) /
kEncodedBlockByteSize;
return Status::OK();
}
// If not padded, calculate the extra bytes, if any.
auto extraBytes = inputSize % kEncodedBlockByteSize;
auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;
decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;

// Adjust the needed size for extra bytes, if present
if (extraBytes) {
if (extraBytes == 1) {
VELOX_USER_FAIL(
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize;
}

return decodedSize;
return Status::OK();
}

// static
Expand All @@ -406,43 +419,47 @@ Status Base64::decodeImpl(
return Status::OK();
}

auto decodedSize = calculateDecodedSize(input, inputSize);
size_t decodedSize;
(void)calculateDecodedSize(input, inputSize, decodedSize);
if (outputSize < decodedSize) {
return Status::UserError(
"Base64::decode() - invalid output string: output string is too small.");
}

Status lookupStatus;
// Handle full groups of 4 characters
for (; inputSize > 4; inputSize -= 4, input += 4, outputBuffer += 3) {
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t decodedBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
outputBuffer[2] = decodedBlock & 0xff;
(base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) |
(base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12) |
(base64ReverseLookup(input[2], reverseIndex, lookupStatus) << 6) |
base64ReverseLookup(input[3], reverseIndex, lookupStatus);
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
outputBuffer[0] = (decodedBlock >> 16) & 0xff;
uint32_t decodedBlock =
(base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) |
(base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12);
outputBuffer[0] = static_cast<char>((decodedBlock >> 16) & 0xff);
if (inputSize > 2) {
decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
outputBuffer[1] = (decodedBlock >> 8) & 0xff;
decodedBlock |= base64ReverseLookup(input[2], reverseIndex, lookupStatus)
<< 6;
outputBuffer[1] = static_cast<char>((decodedBlock >> 8) & 0xff);
if (inputSize > 3) {
decodedBlock |= base64ReverseLookup(input[3], reverseIndex);
outputBuffer[2] = decodedBlock & 0xff;
decodedBlock |= base64ReverseLookup(input[3], reverseIndex, lookupStatus);
outputBuffer[2] = static_cast<char>(decodedBlock & 0xff);
}
}

return Status::OK();
return (lookupStatus != Status::OK()) ? lookupStatus : Status::OK();
}

// static
Expand Down Expand Up @@ -483,7 +500,9 @@ void Base64::decodeUrl(
const std::pair<const char*, int32_t>& payload,
std::string& decodedOutput) {
size_t inputSize = payload.second;
decodedOutput.resize(calculateDecodedSize(payload.first, inputSize));
size_t decodedSize;
(void)calculateDecodedSize(payload.first, inputSize, decodedSize);
decodedOutput.resize(decodedSize);
(void)Base64::decodeImpl(
payload.first,
payload.second,
Expand Down
14 changes: 9 additions & 5 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ class Base64 {
static std::string encode(const char* input, size_t inputSize);
static std::string encode(folly::StringPiece text);
static std::string encode(const folly::IOBuf* inputBuffer);
static void encode(const char* input, size_t inputSize, char* outputBuffer);
static Status encode(const char* input, size_t inputSize, char* outputBuffer);

/// Encodes the input data using Base64 URL encoding.
static std::string encodeUrl(const char* input, size_t inputSize);
static std::string encodeUrl(folly::StringPiece text);
static std::string encodeUrl(const folly::IOBuf* inputBuffer);
static void
static Status
encodeUrl(const char* input, size_t inputSize, char* outputBuffer);

// Decoding Functions
Expand Down Expand Up @@ -83,7 +83,10 @@ class Base64 {

/// Calculates the decoded size based on encoded input and adjusts the input
/// size for padding.
static size_t calculateDecodedSize(const char* input, size_t& inputSize);
static Status calculateDecodedSize(
const char* input,
size_t& inputSize,
size_t& decodedSize);

private:
// Checks if the input Base64 string is padded.
Expand All @@ -105,14 +108,15 @@ class Base64 {
// character.
static uint8_t base64ReverseLookup(
char encodedChar,
const ReverseIndex& reverseIndex);
const ReverseIndex& reverseIndex,
Status& status);

template <class T>
static std::string
encodeImpl(const T& input, const Charset& charset, bool includePadding);

template <class T>
static void encodeImpl(
static Status encodeImpl(
const T& input,
const Charset& charset,
bool includePadding,
Expand Down
5 changes: 1 addition & 4 deletions velox/common/encode/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,4 @@ if(${VELOX_BUILD_TESTING})
endif()

velox_add_library(velox_encode Base64.cpp)
velox_link_libraries(
velox_encode
PUBLIC Folly::folly
PRIVATE velox_status)
velox_link_libraries(velox_encode PUBLIC Folly::folly)
39 changes: 22 additions & 17 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,43 +50,48 @@ TEST_F(Base64Test, fromBase64) {

TEST_F(Base64Test, calculateDecodedSizeProperSize) {
size_t encoded_size{0};
size_t decoded_size{0};

encoded_size = 20;
EXPECT_EQ(
13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size));
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 18;
EXPECT_EQ(
13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size));
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 21;
VELOX_ASSERT_THROW(
Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size),
"Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4.");

encoded_size = 32;
EXPECT_EQ(
23,
Status::UserError(
"Base64::decode() - invalid input string: string length is not a multiple of 4."),
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size));
"SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size));

encoded_size = 32;
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 31;
EXPECT_EQ(
23,
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size));
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 16;
EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size));
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);

encoded_size = 14;
EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size));
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);
}

TEST_F(Base64Test, checksPadding) {
Expand Down
Loading

0 comments on commit 3e7a092

Please sign in to comment.