Skip to content

Commit cf82d7f

Browse files
fix bloomfilter bit_ size
1 parent 2b1e096 commit cf82d7f

File tree

1 file changed

+57
-22
lines changed

1 file changed

+57
-22
lines changed

velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ struct BloomFilterAccumulator {
4747
bloomFilter.merge(output);
4848
}
4949

50+
void init(int32_t capacity) {
51+
if (!bloomFilter.isSet()) {
52+
bloomFilter.reset(capacity);
53+
}
54+
}
55+
5056
BloomFilter<int64_t, false> bloomFilter;
5157
};
5258

@@ -80,9 +86,7 @@ class BloomFilterAggAggregate : public exec::Aggregate {
8086
VELOX_CHECK(!decodedRaw_.mayHaveNulls());
8187
rows.applyToSelected([&](vector_size_t row) {
8288
auto accumulator = value<BloomFilterAccumulator>(groups[row]);
83-
if (!accumulator->bloomFilter.isSet()) {
84-
accumulator->bloomFilter.reset(numBits_);
85-
}
89+
accumulator->init(capacity_);
8690
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
8791
});
8892
}
@@ -111,18 +115,14 @@ class BloomFilterAggAggregate : public exec::Aggregate {
111115
bool /*mayPushdown*/) override {
112116
decodeArguments(rows, args);
113117
auto accumulator = value<BloomFilterAccumulator>(group);
114-
// if (decodedRaw_.isConstantMapping()) {
115-
// // all values are same, just do for the first
116-
// if (!accumulator->bloomFilter.isSet()) {
117-
// accumulator->bloomFilter.reset(numBits_);
118-
// }
119-
// accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
120-
// return;
121-
// }
118+
if (decodedRaw_.isConstantMapping()) {
119+
// all values are same, just do for the first
120+
accumulator->init(capacity_);
121+
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
122+
return;
123+
}
122124
rows.applyToSelected([&](vector_size_t row) {
123-
if (!accumulator->bloomFilter.isSet()) {
124-
accumulator->bloomFilter.reset(numBits_);
125-
}
125+
accumulator->init(capacity_);
126126
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
127127
});
128128
}
@@ -177,16 +177,39 @@ class BloomFilterAggAggregate : public exec::Aggregate {
177177
}
178178

179179
private:
180+
const int64_t DEFAULT_ESPECTED_NUM_ITEMS = 1000000;
181+
const int64_t MAX_NUM_ITEMS = 4000000;
182+
const int64_t MAX_NUM_BITS = 67108864;
183+
180184
void decodeArguments(
181185
const SelectivityVector& rows,
182186
const std::vector<VectorPtr>& args) {
183-
VELOX_CHECK_EQ(args.size(), 3);
184-
decodedRaw_.decode(*args[0], rows);
185-
DecodedVector decodedEstimatedNumItems(*args[1], rows);
186-
DecodedVector decodedNumBits(*args[2], rows);
187-
setConstantArgument(
188-
"estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
189-
setConstantArgument("numBits", numBits_, decodedNumBits);
187+
if (args.size() > 0) {
188+
decodedRaw_.decode(*args[0], rows);
189+
if (args.size() > 1) {
190+
DecodedVector decodedEstimatedNumItems(*args[1], rows);
191+
setConstantArgument(
192+
"estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
193+
if (args.size() > 2) {
194+
DecodedVector decodedNumBits(*args[2], rows);
195+
setConstantArgument("numBits", numBits_, decodedNumBits);
196+
} else {
197+
VELOX_CHECK_EQ(args.size(), 3);
198+
numBits_ = estimatedNumItems_ * 8;
199+
}
200+
} else {
201+
estimatedNumItems_ = DEFAULT_ESPECTED_NUM_ITEMS;
202+
numBits_ = estimatedNumItems_ * 8;
203+
}
204+
} else {
205+
VELOX_USER_FAIL("Function args size must be more than 0")
206+
}
207+
estimatedNumItems_ = std::min(estimatedNumItems_, MAX_NUM_ITEMS);
208+
numBits_ = std::min(numBits_, MAX_NUM_BITS);
209+
// velox BloomFilter bit_ size is bits::nextPowerOfTwo(capacity) / 4, and
210+
// spark bit_ size is Math.ceil(numBits / 64.0) so there is equal bit_ size
211+
// using numBits_ / 16
212+
capacity_ = numBits_ / 16;
190213
}
191214

192215
static void
@@ -211,12 +234,13 @@ class BloomFilterAggAggregate : public exec::Aggregate {
211234
setConstantArgument(name, val, vec.valueAt<int64_t>(0));
212235
}
213236

214-
// Reusable instance of DecodedVector for decoding input vectors.
215237
static constexpr int64_t kMissingArgument = -1;
238+
// Reusable instance of DecodedVector for decoding input vectors.
216239
DecodedVector decodedRaw_;
217240
DecodedVector decodedIntermediate_;
218241
int64_t estimatedNumItems_ = kMissingArgument;
219242
int64_t numBits_ = kMissingArgument;
243+
int32_t capacity_ = kMissingArgument;
220244
};
221245

222246
} // namespace
@@ -226,6 +250,17 @@ bool registerBloomFilterAggAggregate(const std::string& name) {
226250
exec::AggregateFunctionSignatureBuilder()
227251
.argumentType("bigint")
228252
.argumentType("bigint")
253+
.argumentType("bigint")
254+
.intermediateType("varbinary")
255+
.returnType("varbinary")
256+
.build(),
257+
exec::AggregateFunctionSignatureBuilder()
258+
.argumentType("bigint")
259+
.argumentType("bigint")
260+
.intermediateType("varbinary")
261+
.returnType("varbinary")
262+
.build(),
263+
exec::AggregateFunctionSignatureBuilder()
229264
.argumentType("bigint")
230265
.intermediateType("varbinary")
231266
.returnType("varbinary")

0 commit comments

Comments
 (0)