Skip to content

Commit

Permalink
Refactor Base64 to use EncoderUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Oct 17, 2024
1 parent 9cf55d5 commit 7bf904d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 126 deletions.
111 changes: 9 additions & 102 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

// Validate the character in charset with ReverseIndex table
constexpr bool checkForwardIndex(
uint8_t index,
const Base64::Charset& charset,
const Base64::ReverseIndex& reverseIndex) {
return (reverseIndex[static_cast<uint8_t>(charset[index])] == index) &&
(index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true);
}

// Verify that for every entry in kBase64Charset, the corresponding entry
// in kBase64ReverseIndexTable is correct.
static_assert(
Expand All @@ -112,28 +103,6 @@ static_assert(
kBase64UrlReverseIndexTable),
"kBase64UrlCharset has incorrect entries");

// Searches for a character within a charset up to a certain index.
constexpr bool findCharacterInCharset(
const Base64::Charset& charset,
uint8_t index,
const char targetChar) {
return index < charset.size() &&
((charset[index] == targetChar) ||
findCharacterInCharset(charset, index + 1, targetChar));
}

// Checks the consistency of a reverse index mapping for a given character
// set.
constexpr bool checkReverseIndex(
uint8_t index,
const Base64::Charset& charset,
const Base64::ReverseIndex& reverseIndex) {
return (reverseIndex[index] == 255
? !findCharacterInCharset(charset, 0, static_cast<char>(index))
: (charset[reverseIndex[index]] == index)) &&
(index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true);
}

// Verify that for every entry in kBase64ReverseIndexTable, the corresponding
// entry in kBase64Charset is correct.
static_assert(
Expand Down Expand Up @@ -166,21 +135,6 @@ std::string Base64::encodeImpl(
return encodedResult;
}

// static
size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) {
if (inputSize == 0) {
return 0;
}

// Calculate the output size assuming that we are including padding.
size_t encodedSize = ((inputSize + 2) / 3) * 4;
if (!includePadding) {
// If the padding was not requested, subtract the padding bytes.
encodedSize -= (3 - (inputSize % 3)) % 3;
}
return encodedSize;
}

// static
Status Base64::encode(std::string_view input, std::string& output) {
return encodeImpl(input, kBase64Charset, true, output);
Expand All @@ -205,7 +159,8 @@ Status Base64::encodeImpl(
}

// Calculate the output size and resize the string beforehand
size_t outputSize = calculateEncodedSize(inputSize, includePadding);
size_t outputSize = calculateEncodedSize(
inputSize, includePadding, kBinaryBlockByteSize, kEncodedBlockByteSize);
output.resize(outputSize); // Resize the output string to the required size

// Use a pointer to write into the pre-allocated buffer
Expand Down Expand Up @@ -337,67 +292,14 @@ uint8_t Base64::base64ReverseLookup(
char encodedChar,
const ReverseIndex& reverseIndex,
Status& status) {
auto reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
if (reverseLookupValue >= 0x40) {
status = Status::UserError(
"decode() - invalid input string: invalid characters");
}
return reverseLookupValue;
return reverseLookup(encodedChar, reverseIndex, status, kCharsetSize);
}

// static
Status Base64::decode(std::string_view input, std::string& output) {
return decodeImpl(input, output, kBase64ReverseIndexTable);
}

// static
Status Base64::calculateDecodedSize(
std::string_view input,
size_t& inputSize,
size_t& decodedSize) {
if (inputSize == 0) {
decodedSize = 0;
return Status::OK();
}

// Check if the input string is padded
if (isPadded(input)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (inputSize % kEncodedBlockByteSize != 0) {
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length is not a multiple of 4.");
}

decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize;
auto paddingCount = numPadding(input);
inputSize -= paddingCount;

// Adjust the needed size by deducting the bytes corresponding to the
// padding from the calculated size.
decodedSize -=
((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) /
kEncodedBlockByteSize;
return Status::OK();
}
// If not padded, calculate extra bytes, if any
auto extraBytes = inputSize % kEncodedBlockByteSize;
decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize;

// Adjust the needed size for extra bytes, if present
if (extraBytes) {
if (extraBytes == 1) {
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize;
}

return Status::OK();
}

// static
Status Base64::decodeImpl(
std::string_view input,
Expand All @@ -411,7 +313,12 @@ Status Base64::decodeImpl(

// Calculate the decoded size based on the input size
size_t decodedSize;
auto status = calculateDecodedSize(input, inputSize, decodedSize);
auto status = calculateDecodedSize(
input,
inputSize,
decodedSize,
kBinaryBlockByteSize,
kEncodedBlockByteSize);
if (!status.ok()) {
return status;
}
Expand Down
11 changes: 1 addition & 10 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>
#include "velox/common/base/GTestMacros.h"
#include "velox/common/base/Status.h"
#include "velox/common/encode/EncoderUtils.h"

namespace facebook::velox::encoding {

Expand Down Expand Up @@ -109,16 +110,6 @@ class Base64 {
std::string& output,
const ReverseIndex& reverseIndex);

// Returns the actual size of the decoded data. Will also remove the padding
// length from the 'inputSize'.
static Status calculateDecodedSize(
std::string_view input,
size_t& inputSize,
size_t& decodedSize);

// Calculates the encoded size based on input size.
static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true);

VELOX_FRIEND_TEST(Base64Test, isPadded);
VELOX_FRIEND_TEST(Base64Test, numPadding);
VELOX_FRIEND_TEST(Base64Test, calculateDecodedSize);
Expand Down
14 changes: 2 additions & 12 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TEST_F(Base64Test, calculateDecodedSize) {
size_t encoded_size = initialEncodedSize;
size_t decoded_size = 0;
Status status =
Base64::calculateDecodedSize(encodedString, encoded_size, decoded_size);
calculateDecodedSize(encodedString, encoded_size, decoded_size, 3, 4);

if (expectedStatus.ok()) {
EXPECT_EQ(Status::OK(), status);
Expand All @@ -75,21 +75,11 @@ TEST_F(Base64Test, calculateDecodedSize) {
0,
0,
Status::UserError(
"Base64::decode() - invalid input string: string length is not a multiple of 4."));
"decode() - invalid input string length."));
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 31, 23);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 31, 23);
checkDecodedSize("MTIzNDU2Nzg5MA==", 16, 14, 10);
checkDecodedSize("MTIzNDU2Nzg5MA", 14, 14, 10);
}

TEST_F(Base64Test, isPadded) {
EXPECT_TRUE(Base64::isPadded("ABC="));
EXPECT_FALSE(Base64::isPadded("ABC"));
}

TEST_F(Base64Test, numPadding) {
EXPECT_EQ(0, Base64::numPadding("ABC"));
EXPECT_EQ(1, Base64::numPadding("ABC="));
EXPECT_EQ(2, Base64::numPadding("AB=="));
}
} // namespace facebook::velox::encoding
4 changes: 2 additions & 2 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ TEST_F(BinaryFunctionsTest, fromBase64) {

VELOX_ASSERT_USER_THROW(
fromBase64("YQ="),
"Base64::decode() - invalid input string: string length is not a multiple of 4.");
"decode() - invalid input string length.");
VELOX_ASSERT_USER_THROW(
fromBase64("YQ==="),
"Base64::decode() - invalid input string: string length is not a multiple of 4.");
"decode() - invalid input string length.");

// Check encoded strings without padding
EXPECT_EQ("a", fromBase64("YQ"));
Expand Down

0 comments on commit 7bf904d

Please sign in to comment.