Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions velox/common/base/BloomFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#include <vector>

#include <folly/Hash.h>

#include "velox/common/base/BitUtil.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/IOUtils.h"
#include "velox/type/StringView.h"

namespace facebook::velox {

Expand All @@ -31,9 +33,15 @@ namespace facebook::velox {
// expected entry, we get ~2% false positives. 'hashInput' determines
// if the value added or checked needs to be hashed. If this is false,
// we assume that the input is already a 64 bit hash number.
template <bool hashInput = true>
// case:
// InputType can be one of folly hasher support type when hashInput = false
// InputType can only be uint64_t when hashInput = true
template <class InputType = uint64_t, bool hashInput = true>
class BloomFilter {
public:
BloomFilter(){};
BloomFilter(std::vector<uint64_t> bits) : bits_(bits){};

// Prepares 'this' for use with an expected 'capacity'
// entries. Drops any prior content.
void reset(int32_t capacity) {
Expand All @@ -42,18 +50,61 @@ class BloomFilter {
bits_.resize(std::max<int32_t>(4, bits::nextPowerOfTwo(capacity) / 4));
}

bool isSet() {
return bits_.size() > 0;
}

// Adds 'value'.
void insert(uint64_t value) {
void insert(InputType value) {
set(bits_.data(),
bits_.size(),
hashInput ? folly::hasher<uint64_t>()(value) : value);
hashInput ? folly::hasher<InputType>()(value) : value);
}

bool mayContain(uint64_t value) const {
bool mayContain(InputType value) const {
return test(
bits_.data(),
bits_.size(),
hashInput ? folly::hasher<uint64_t>()(value) : value);
hashInput ? folly::hasher<InputType>()(value) : value);
}

// Combines the two bloomFilter bits_ using bitwise OR.
void merge(BloomFilter& bloomFilter) {
if (bits_.size() == 0) {
bits_ = bloomFilter.bits_;
return;
} else if (bloomFilter.bits_.size() == 0) {
return;
}
VELOX_CHECK_EQ(bits_.size(), bloomFilter.bits_.size());
for (auto i = 0; i < bloomFilter.bits_.size(); i++) {
bits_[i] |= bloomFilter.bits_[i];
}
}

uint32_t serializedSize() {
return 4 /* number of bits */
+ bits_.size() * 8;
}

void serialize(StringView& output) {
char* outputBuffer = const_cast<char*>(output.data());
common::OutputByteStream stream(outputBuffer);
stream.appendOne((int32_t)bits_.size());
for (auto bit : bits_) {
stream.appendOne(bit);
}
}

static void deserialize(const char* serialized, BloomFilter& output) {
common::InputByteStream stream(serialized);
auto size = stream.read<int32_t>();
output.bits_.resize(size);
auto bitsdata =
reinterpret_cast<const uint64_t*>(serialized + stream.offset());
for (auto i = 0; i < size; i++) {
output.bits_[i] = bitsdata[i];
}
}

private:
Expand Down
45 changes: 44 additions & 1 deletion velox/common/base/tests/BloomFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using namespace facebook::velox;

TEST(BloomFilterTest, basic) {
constexpr int32_t kSize = 1024;
BloomFilter bloom;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(i);
Expand All @@ -37,3 +37,46 @@ TEST(BloomFilterTest, basic) {
}
EXPECT_GT(2, 100 * numFalsePositives / kSize);
}

TEST(BloomFilterTest, serialize) {
constexpr int32_t kSize = 1024;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(i);
}
std::string data;
data.resize(bloom.serializedSize());
StringView serialized(data.data(), data.size());
bloom.serialize(serialized);
BloomFilter<int32_t> deserialized;
BloomFilter<int32_t>::deserialize(data.data(), deserialized);
for (auto i = 0; i < kSize; ++i) {
EXPECT_TRUE(deserialized.mayContain(i));
}

EXPECT_EQ(bloom.serializedSize(), deserialized.serializedSize());
}

TEST(BloomFilterTest, merge) {
constexpr int32_t kSize = 10;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(i);
}

BloomFilter<int32_t> merge;
merge.reset(kSize);
for (auto i = kSize; i < kSize + kSize; i++) {
merge.insert(i);
}

bloom.merge(merge);

for (auto i = 0; i < kSize + kSize; ++i) {
EXPECT_TRUE(bloom.mayContain(i));
}

EXPECT_EQ(bloom.serializedSize(), merge.serializedSize());
}
8 changes: 6 additions & 2 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ add_library(
Size.cpp
SplitFunctions.cpp
String.cpp
Subscript.cpp)
Subscript.cpp
MightContain.cpp)

target_link_libraries(
velox_functions_spark velox_functions_lib velox_functions_prestosql_impl
Expand All @@ -36,9 +37,12 @@ target_link_libraries(
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)

if(${VELOX_ENABLE_AGGREGATES})
add_subdirectory(aggregates)
endif()

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
add_subdirectory(aggregates)
endif()

if(${VELOX_ENABLE_BENCHMARKS})
Expand Down
84 changes: 84 additions & 0 deletions velox/functions/sparksql/MightContain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/MightContain.h"

#include "velox/common/base/BloomFilter.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

#include <glog/logging.h>

namespace facebook::velox::functions::sparksql {
namespace {
class BloomFilterMightContainFunction final : public exec::VectorFunction {
bool isDefaultNullBehavior() const final {
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 2);
context.ensureWritable(rows, BOOLEAN(), resultRef);
auto& result = *resultRef->as<FlatVector<bool>>();
exec::DecodedArgs decodedArgs(rows, args, context);
auto serialized = decodedArgs.at(0);
auto value = decodedArgs.at(1);
if (serialized->isConstantMapping() && serialized->isNullAt(0)) {
rows.applyToSelected([&](int row) { result.setNull(row, true); });
return;
}

if (serialized->isConstantMapping()) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(0);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
rows.applyToSelected([&](int row) {
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
return;
}

rows.applyToSelected([&](int row) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(row);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
}
};
} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("boolean")
.argumentType("varbinary")
.argumentType("bigint")
.build()};
}

std::shared_ptr<exec::VectorFunction> makeMightContain(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
static const auto kHashFunction =
std::make_shared<BloomFilterMightContainFunction>();
return kHashFunction;
}

} // namespace facebook::velox::functions::sparksql
26 changes: 26 additions & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures();

std::shared_ptr<exec::VectorFunction> makeMightContain(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
47 changes: 19 additions & 28 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -149,29 +150,27 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "sort_array", sortArraySignatures(), makeSortArray);

// Register bloom filter function
exec::registerStatefulVectorFunction(
prefix + "might_contain", mightContainSignatures(), makeMightContain);

// Register DateTime functions.
registerFunction<MillisecondFunction, int32_t, Date>(
{prefix + "millisecond"});
registerFunction<MillisecondFunction, int32_t, Timestamp>(
{prefix + "millisecond"});
registerFunction<MillisecondFunction, int32_t, TimestampWithTimezone>(
{prefix + "millisecond"});
registerFunction<SecondFunction, int32_t, Date>(
{prefix + "second"});
registerFunction<SecondFunction, int32_t, Timestamp>(
{prefix + "second"});
registerFunction<SecondFunction, int32_t, Date>({prefix + "second"});
registerFunction<SecondFunction, int32_t, Timestamp>({prefix + "second"});
registerFunction<SecondFunction, int32_t, TimestampWithTimezone>(
{prefix + "second"});
registerFunction<MinuteFunction, int32_t, Date>(
{prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Timestamp>(
{prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Date>({prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Timestamp>({prefix + "minute"});
registerFunction<MinuteFunction, int32_t, TimestampWithTimezone>(
{prefix + "minute"});
registerFunction<HourFunction, int32_t, Date>(
{prefix + "hour"});
registerFunction<HourFunction, int32_t, Timestamp>(
{prefix + "hour"});
registerFunction<HourFunction, int32_t, Date>({prefix + "hour"});
registerFunction<HourFunction, int32_t, Timestamp>({prefix + "hour"});
registerFunction<HourFunction, int32_t, TimestampWithTimezone>(
{prefix + "hour"});
registerFunction<DayFunction, int32_t, Date>(
Expand All @@ -180,34 +179,26 @@ void registerFunctions(const std::string& prefix) {
{prefix + "day", prefix + "day_of_month"});
registerFunction<DayFunction, int32_t, TimestampWithTimezone>(
{prefix + "day", prefix + "day_of_month"});
registerFunction<DayOfWeekFunction, int32_t, Date>(
{prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, Date>({prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, Timestamp>(
{prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, TimestampWithTimezone>(
{prefix + "day_of_week"});
registerFunction<DayOfYearFunction, int32_t, Date>(
{prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, Date>({prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, Timestamp>(
{prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, TimestampWithTimezone>(
{prefix + "day_of_year"});
registerFunction<MonthFunction, int32_t, Date>(
{prefix + "month"});
registerFunction<MonthFunction, int32_t, Timestamp>(
{prefix + "month"});
registerFunction<MonthFunction, int32_t, Date>({prefix + "month"});
registerFunction<MonthFunction, int32_t, Timestamp>({prefix + "month"});
registerFunction<MonthFunction, int32_t, TimestampWithTimezone>(
{prefix + "month"});
registerFunction<QuarterFunction, int32_t, Date>(
{prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Timestamp>(
{prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Date>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
{prefix + "quarter"});
registerFunction<YearFunction, int32_t, Date>(
{prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>(
{prefix + "year"});
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
registerFunction<YearFunction, int32_t, TimestampWithTimezone>(
{prefix + "year"});
registerFunction<YearOfWeekFunction, int32_t, Date>(
Expand Down
Loading