Skip to content

Commit 524f857

Browse files
[OPPRO-279] Add bloom_filter_agg and might_contain SparkSql function (#79)
* add Spark SQL functions bloom_filter_agg and might_contain; change bit_ size to fix TPCDS performance * change to stateful function * optimize MightContain * change back to spark value * fix merge bloomfilter * remove comment
1 parent bc489b1 commit 524f857

File tree

15 files changed

+638
-40
lines changed

15 files changed

+638
-40
lines changed

velox/common/base/BloomFilter.h

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
#include <cstring>
#include <vector>
2121

2222
#include <folly/Hash.h>
23-
2423
#include "velox/common/base/BitUtil.h"
24+
#include "velox/common/base/Exceptions.h"
25+
#include "velox/common/base/IOUtils.h"
26+
#include "velox/type/StringView.h"
2527

2628
namespace facebook::velox {
2729

@@ -31,9 +33,15 @@ namespace facebook::velox {
3133
// expected entry, we get ~2% false positives. 'hashInput' determines
3234
// if the value added or checked needs to be hashed. If this is false,
3335
// we assume that the input is already a 64 bit hash number.
34-
template <bool hashInput = true>
36+
// case:
37+
// InputType can be any type supported by folly::hasher when hashInput = true
38+
// InputType must be an integer that already holds a 64-bit hash when hashInput = false
39+
template <class InputType = uint64_t, bool hashInput = true>
3540
class BloomFilter {
3641
public:
42+
BloomFilter(){};
43+
BloomFilter(std::vector<uint64_t> bits) : bits_(bits){};
44+
3745
// Prepares 'this' for use with an expected 'capacity'
3846
// entries. Drops any prior content.
3947
void reset(int32_t capacity) {
@@ -42,18 +50,61 @@ class BloomFilter {
4250
bits_.resize(std::max<int32_t>(4, bits::nextPowerOfTwo(capacity) / 4));
4351
}
4452

53+
bool isSet() {
54+
return bits_.size() > 0;
55+
}
56+
4557
// Adds 'value'.
46-
void insert(uint64_t value) {
58+
void insert(InputType value) {
4759
set(bits_.data(),
4860
bits_.size(),
49-
hashInput ? folly::hasher<uint64_t>()(value) : value);
61+
hashInput ? folly::hasher<InputType>()(value) : value);
5062
}
5163

52-
bool mayContain(uint64_t value) const {
64+
bool mayContain(InputType value) const {
5365
return test(
5466
bits_.data(),
5567
bits_.size(),
56-
hashInput ? folly::hasher<uint64_t>()(value) : value);
68+
hashInput ? folly::hasher<InputType>()(value) : value);
69+
}
70+
71+
// Combines the two bloomFilter bits_ using bitwise OR.
72+
void merge(BloomFilter& bloomFilter) {
73+
if (bits_.size() == 0) {
74+
bits_ = bloomFilter.bits_;
75+
return;
76+
} else if (bloomFilter.bits_.size() == 0) {
77+
return;
78+
}
79+
VELOX_CHECK_EQ(bits_.size(), bloomFilter.bits_.size());
80+
for (auto i = 0; i < bloomFilter.bits_.size(); i++) {
81+
bits_[i] |= bloomFilter.bits_[i];
82+
}
83+
}
84+
85+
uint32_t serializedSize() {
86+
return 4 /* number of bits */
87+
+ bits_.size() * 8;
88+
}
89+
90+
// Writes the filter into the memory that 'output' points at, as an int32_t
// word count followed by the 64-bit words of the bit array. 'output' is
// expected to cover at least serializedSize() bytes; this is checked up
// front so a too-small buffer fails loudly instead of being overrun. The
// view only exposes a const pointer, so constness is cast away to write
// through it.
void serialize(StringView& output) {
  VELOX_CHECK_GE(output.size(), serializedSize());
  common::OutputByteStream stream(const_cast<char*>(output.data()));
  stream.appendOne(static_cast<int32_t>(bits_.size()));
  for (const auto word : bits_) {
    stream.appendOne(word);
  }
}
98+
99+
static void deserialize(const char* serialized, BloomFilter& output) {
100+
common::InputByteStream stream(serialized);
101+
auto size = stream.read<int32_t>();
102+
output.bits_.resize(size);
103+
auto bitsdata =
104+
reinterpret_cast<const uint64_t*>(serialized + stream.offset());
105+
for (auto i = 0; i < size; i++) {
106+
output.bits_[i] = bitsdata[i];
107+
}
57108
}
58109

59110
private:

velox/common/base/tests/BloomFilterTest.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using namespace facebook::velox;
2424

2525
TEST(BloomFilterTest, basic) {
2626
constexpr int32_t kSize = 1024;
27-
BloomFilter bloom;
27+
BloomFilter<int32_t> bloom;
2828
bloom.reset(kSize);
2929
for (auto i = 0; i < kSize; ++i) {
3030
bloom.insert(i);
@@ -37,3 +37,46 @@ TEST(BloomFilterTest, basic) {
3737
}
3838
EXPECT_GT(2, 100 * numFalsePositives / kSize);
3939
}
40+
41+
// Round-trips a populated filter through serialize()/deserialize() and
// verifies that every inserted value is still reported and that the
// serialized size is unchanged.
TEST(BloomFilterTest, serialize) {
  constexpr int32_t kNumValues = 1024;
  BloomFilter<int32_t> original;
  original.reset(kNumValues);
  for (int32_t value = 0; value < kNumValues; ++value) {
    original.insert(value);
  }

  std::string buffer(original.serializedSize(), '\0');
  StringView view(buffer.data(), buffer.size());
  original.serialize(view);

  BloomFilter<int32_t> restored;
  BloomFilter<int32_t>::deserialize(buffer.data(), restored);
  for (int32_t value = 0; value < kNumValues; ++value) {
    EXPECT_TRUE(restored.mayContain(value));
  }

  EXPECT_EQ(original.serializedSize(), restored.serializedSize());
}
60+
61+
// Fills two equally sized filters with disjoint value ranges, merges the
// second into the first, and verifies the result reports the values of both.
TEST(BloomFilterTest, merge) {
  constexpr int32_t kNumValues = 10;

  BloomFilter<int32_t> target;
  target.reset(kNumValues);
  for (int32_t value = 0; value < kNumValues; ++value) {
    target.insert(value);
  }

  BloomFilter<int32_t> other;
  other.reset(kNumValues);
  for (int32_t value = kNumValues; value < 2 * kNumValues; ++value) {
    other.insert(value);
  }

  target.merge(other);

  for (int32_t value = 0; value < 2 * kNumValues; ++value) {
    EXPECT_TRUE(target.mayContain(value));
  }

  EXPECT_EQ(target.serializedSize(), other.serializedSize());
}

velox/functions/sparksql/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ add_library(
2727
Size.cpp
2828
SplitFunctions.cpp
2929
String.cpp
30-
Subscript.cpp)
30+
Subscript.cpp
31+
MightContain.cpp)
3132

3233
target_link_libraries(
3334
velox_functions_spark velox_functions_lib velox_functions_prestosql_impl
@@ -36,9 +37,12 @@ target_link_libraries(
3637
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
3738
high_memory_pool)
3839

40+
if(${VELOX_ENABLE_AGGREGATES})
41+
add_subdirectory(aggregates)
42+
endif()
43+
3944
if(${VELOX_BUILD_TESTING})
4045
add_subdirectory(tests)
41-
add_subdirectory(aggregates)
4246
endif()
4347

4448
if(${VELOX_ENABLE_BENCHMARKS})
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "velox/functions/sparksql/MightContain.h"
17+
18+
#include "velox/common/base/BloomFilter.h"
19+
#include "velox/expression/DecodedArgs.h"
20+
#include "velox/vector/FlatVector.h"
21+
22+
#include <glog/logging.h>
23+
24+
namespace facebook::velox::functions::sparksql {
25+
namespace {
26+
class BloomFilterMightContainFunction final : public exec::VectorFunction {
27+
bool isDefaultNullBehavior() const final {
28+
return false;
29+
}
30+
31+
void apply(
32+
const SelectivityVector& rows,
33+
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
34+
const TypePtr& outputType,
35+
exec::EvalCtx& context,
36+
VectorPtr& resultRef) const final {
37+
VELOX_CHECK_EQ(args.size(), 2);
38+
context.ensureWritable(rows, BOOLEAN(), resultRef);
39+
auto& result = *resultRef->as<FlatVector<bool>>();
40+
exec::DecodedArgs decodedArgs(rows, args, context);
41+
auto serialized = decodedArgs.at(0);
42+
auto value = decodedArgs.at(1);
43+
if (serialized->isConstantMapping() && serialized->isNullAt(0)) {
44+
rows.applyToSelected([&](int row) { result.setNull(row, true); });
45+
return;
46+
}
47+
48+
if (serialized->isConstantMapping()) {
49+
BloomFilter<int64_t, false> output;
50+
auto serializedBloom = serialized->valueAt<StringView>(0);
51+
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
52+
rows.applyToSelected([&](int row) {
53+
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
54+
});
55+
return;
56+
}
57+
58+
rows.applyToSelected([&](int row) {
59+
BloomFilter<int64_t, false> output;
60+
auto serializedBloom = serialized->valueAt<StringView>(row);
61+
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
62+
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
63+
});
64+
}
65+
};
66+
} // namespace
67+
68+
// Declares the single supported signature: (varbinary, bigint) -> boolean.
std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures() {
  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(exec::FunctionSignatureBuilder()
                           .returnType("boolean")
                           .argumentType("varbinary")
                           .argumentType("bigint")
                           .build());
  return signatures;
}
75+
76+
std::shared_ptr<exec::VectorFunction> makeMightContain(
77+
const std::string& name,
78+
const std::vector<exec::VectorFunctionArg>& inputArgs) {
79+
static const auto kHashFunction =
80+
std::make_shared<BloomFilterMightContainFunction>();
81+
return kHashFunction;
82+
}
83+
84+
} // namespace facebook::velox::functions::sparksql
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "velox/expression/VectorFunction.h"
17+
18+
namespace facebook::velox::functions::sparksql {
19+
20+
// Returns the signatures accepted by the might-contain function:
// (varbinary, bigint) -> boolean.
std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures();
21+
22+
// Factory for the might-contain vector function, in the form expected by
// exec::registerStatefulVectorFunction.
// NOTE(review): the definition in this change ignores 'name' and 'inputArgs'
// and returns one shared instance - confirm before relying on either argument.
std::shared_ptr<exec::VectorFunction> makeMightContain(
23+
const std::string& name,
24+
const std::vector<exec::VectorFunctionArg>& inputArgs);
25+
26+
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/Register.cpp

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "velox/functions/sparksql/Hash.h"
2929
#include "velox/functions/sparksql/In.h"
3030
#include "velox/functions/sparksql/LeastGreatest.h"
31+
#include "velox/functions/sparksql/MightContain.h"
3132
#include "velox/functions/sparksql/RegexFunctions.h"
3233
#include "velox/functions/sparksql/RegisterArithmetic.h"
3334
#include "velox/functions/sparksql/RegisterCompare.h"
@@ -149,29 +150,27 @@ void registerFunctions(const std::string& prefix) {
149150
exec::registerStatefulVectorFunction(
150151
prefix + "sort_array", sortArraySignatures(), makeSortArray);
151152

153+
// Register bloom filter function
154+
exec::registerStatefulVectorFunction(
155+
prefix + "might_contain", mightContainSignatures(), makeMightContain);
156+
152157
// Register DateTime functions.
153158
registerFunction<MillisecondFunction, int32_t, Date>(
154159
{prefix + "millisecond"});
155160
registerFunction<MillisecondFunction, int32_t, Timestamp>(
156161
{prefix + "millisecond"});
157162
registerFunction<MillisecondFunction, int32_t, TimestampWithTimezone>(
158163
{prefix + "millisecond"});
159-
registerFunction<SecondFunction, int32_t, Date>(
160-
{prefix + "second"});
161-
registerFunction<SecondFunction, int32_t, Timestamp>(
162-
{prefix + "second"});
164+
registerFunction<SecondFunction, int32_t, Date>({prefix + "second"});
165+
registerFunction<SecondFunction, int32_t, Timestamp>({prefix + "second"});
163166
registerFunction<SecondFunction, int32_t, TimestampWithTimezone>(
164167
{prefix + "second"});
165-
registerFunction<MinuteFunction, int32_t, Date>(
166-
{prefix + "minute"});
167-
registerFunction<MinuteFunction, int32_t, Timestamp>(
168-
{prefix + "minute"});
168+
registerFunction<MinuteFunction, int32_t, Date>({prefix + "minute"});
169+
registerFunction<MinuteFunction, int32_t, Timestamp>({prefix + "minute"});
169170
registerFunction<MinuteFunction, int32_t, TimestampWithTimezone>(
170171
{prefix + "minute"});
171-
registerFunction<HourFunction, int32_t, Date>(
172-
{prefix + "hour"});
173-
registerFunction<HourFunction, int32_t, Timestamp>(
174-
{prefix + "hour"});
172+
registerFunction<HourFunction, int32_t, Date>({prefix + "hour"});
173+
registerFunction<HourFunction, int32_t, Timestamp>({prefix + "hour"});
175174
registerFunction<HourFunction, int32_t, TimestampWithTimezone>(
176175
{prefix + "hour"});
177176
registerFunction<DayFunction, int32_t, Date>(
@@ -180,34 +179,26 @@ void registerFunctions(const std::string& prefix) {
180179
{prefix + "day", prefix + "day_of_month"});
181180
registerFunction<DayFunction, int32_t, TimestampWithTimezone>(
182181
{prefix + "day", prefix + "day_of_month"});
183-
registerFunction<DayOfWeekFunction, int32_t, Date>(
184-
{prefix + "day_of_week"});
182+
registerFunction<DayOfWeekFunction, int32_t, Date>({prefix + "day_of_week"});
185183
registerFunction<DayOfWeekFunction, int32_t, Timestamp>(
186184
{prefix + "day_of_week"});
187185
registerFunction<DayOfWeekFunction, int32_t, TimestampWithTimezone>(
188186
{prefix + "day_of_week"});
189-
registerFunction<DayOfYearFunction, int32_t, Date>(
190-
{prefix + "day_of_year"});
187+
registerFunction<DayOfYearFunction, int32_t, Date>({prefix + "day_of_year"});
191188
registerFunction<DayOfYearFunction, int32_t, Timestamp>(
192189
{prefix + "day_of_year"});
193190
registerFunction<DayOfYearFunction, int32_t, TimestampWithTimezone>(
194191
{prefix + "day_of_year"});
195-
registerFunction<MonthFunction, int32_t, Date>(
196-
{prefix + "month"});
197-
registerFunction<MonthFunction, int32_t, Timestamp>(
198-
{prefix + "month"});
192+
registerFunction<MonthFunction, int32_t, Date>({prefix + "month"});
193+
registerFunction<MonthFunction, int32_t, Timestamp>({prefix + "month"});
199194
registerFunction<MonthFunction, int32_t, TimestampWithTimezone>(
200195
{prefix + "month"});
201-
registerFunction<QuarterFunction, int32_t, Date>(
202-
{prefix + "quarter"});
203-
registerFunction<QuarterFunction, int32_t, Timestamp>(
204-
{prefix + "quarter"});
196+
registerFunction<QuarterFunction, int32_t, Date>({prefix + "quarter"});
197+
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
205198
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
206199
{prefix + "quarter"});
207-
registerFunction<YearFunction, int32_t, Date>(
208-
{prefix + "year"});
209-
registerFunction<YearFunction, int32_t, Timestamp>(
210-
{prefix + "year"});
200+
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
201+
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
211202
registerFunction<YearFunction, int32_t, TimestampWithTimezone>(
212203
{prefix + "year"});
213204
registerFunction<YearOfWeekFunction, int32_t, Date>(

0 commit comments

Comments
 (0)