@@ -47,6 +47,12 @@ struct BloomFilterAccumulator {
4747 bloomFilter.merge (output);
4848 }
4949
50+ void init (int32_t capacity) {
51+ if (!bloomFilter.isSet ()) {
52+ bloomFilter.reset (capacity);
53+ }
54+ }
55+
5056 BloomFilter<int64_t , false > bloomFilter;
5157};
5258
@@ -80,9 +86,7 @@ class BloomFilterAggAggregate : public exec::Aggregate {
8086 VELOX_CHECK (!decodedRaw_.mayHaveNulls ());
8187 rows.applyToSelected ([&](vector_size_t row) {
8288 auto accumulator = value<BloomFilterAccumulator>(groups[row]);
83- if (!accumulator->bloomFilter .isSet ()) {
84- accumulator->bloomFilter .reset (numBits_);
85- }
89+ accumulator->init (capacity_);
8690 accumulator->bloomFilter .insert (decodedRaw_.valueAt <int64_t >(row));
8791 });
8892 }
@@ -111,18 +115,14 @@ class BloomFilterAggAggregate : public exec::Aggregate {
111115 bool /* mayPushdown*/ ) override {
112116 decodeArguments (rows, args);
113117 auto accumulator = value<BloomFilterAccumulator>(group);
114- // if (decodedRaw_.isConstantMapping()) {
115- // // all values are same, just do for the first
116- // if (!accumulator->bloomFilter.isSet()) {
117- // accumulator->bloomFilter.reset(numBits_);
118- // }
119- // accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
120- // return;
121- // }
118+ if (decodedRaw_.isConstantMapping ()) {
119+ // all values are same, just do for the first
120+ accumulator->init (capacity_);
121+ accumulator->bloomFilter .insert (decodedRaw_.valueAt <int64_t >(0 ));
122+ return ;
123+ }
122124 rows.applyToSelected ([&](vector_size_t row) {
123- if (!accumulator->bloomFilter .isSet ()) {
124- accumulator->bloomFilter .reset (numBits_);
125- }
125+ accumulator->init (capacity_);
126126 accumulator->bloomFilter .insert (decodedRaw_.valueAt <int64_t >(row));
127127 });
128128 }
@@ -177,16 +177,39 @@ class BloomFilterAggAggregate : public exec::Aggregate {
177177 }
178178
179179 private:
// Sizing limits and defaults for the bloom filter. Declared `static constexpr`
// (like kMissingArgument) so they are shared compile-time constants rather
// than per-instance const data members.

// Estimated item count used when the caller passes only the value argument.
// NOTE: "ESPECTED" is a misspelling of "EXPECTED"; the name is kept as-is
// because other code refers to it.
static constexpr int64_t DEFAULT_ESPECTED_NUM_ITEMS = 1000000;
// Upper bound applied to the estimated number of items.
static constexpr int64_t MAX_NUM_ITEMS = 4000000;
// Upper bound applied to the requested number of bits (64 * 1024 * 1024).
static constexpr int64_t MAX_NUM_BITS = 67108864;
183+
180184 void decodeArguments (
181185 const SelectivityVector& rows,
182186 const std::vector<VectorPtr>& args) {
183- VELOX_CHECK_EQ (args.size (), 3 );
184- decodedRaw_.decode (*args[0 ], rows);
185- DecodedVector decodedEstimatedNumItems (*args[1 ], rows);
186- DecodedVector decodedNumBits (*args[2 ], rows);
187- setConstantArgument (
188- " estimatedNumItems" , estimatedNumItems_, decodedEstimatedNumItems);
189- setConstantArgument (" numBits" , numBits_, decodedNumBits);
187+ if (args.size () > 0 ) {
188+ decodedRaw_.decode (*args[0 ], rows);
189+ if (args.size () > 1 ) {
190+ DecodedVector decodedEstimatedNumItems (*args[1 ], rows);
191+ setConstantArgument (
192+ " estimatedNumItems" , estimatedNumItems_, decodedEstimatedNumItems);
193+ if (args.size () > 2 ) {
194+ DecodedVector decodedNumBits (*args[2 ], rows);
195+ setConstantArgument (" numBits" , numBits_, decodedNumBits);
196+ } else {
197+ VELOX_CHECK_EQ (args.size (), 3 );
198+ numBits_ = estimatedNumItems_ * 8 ;
199+ }
200+ } else {
201+ estimatedNumItems_ = DEFAULT_ESPECTED_NUM_ITEMS;
202+ numBits_ = estimatedNumItems_ * 8 ;
203+ }
204+ } else {
205+ VELOX_USER_FAIL (" Function args size must be more than 0" )
206+ }
207+ estimatedNumItems_ = std::min (estimatedNumItems_, MAX_NUM_ITEMS);
208+ numBits_ = std::min (numBits_, MAX_NUM_BITS);
209+ // velox BloomFilter bit_ size is bits::nextPowerOfTwo(capacity) / 4, and
210+ // spark bit_ size is Math.ceil(numBits / 64.0) so there is equal bit_ size
211+ // using numBits_ / 16
212+ capacity_ = numBits_ / 16 ;
190213 }
191214
192215 static void
@@ -211,12 +234,13 @@ class BloomFilterAggAggregate : public exec::Aggregate {
211234 setConstantArgument (name, val, vec.valueAt <int64_t >(0 ));
212235 }
213236
214- // Reusable instance of DecodedVector for decoding input vectors.
215237 static constexpr int64_t kMissingArgument = -1 ;
238+ // Reusable instance of DecodedVector for decoding input vectors.
216239 DecodedVector decodedRaw_;
217240 DecodedVector decodedIntermediate_;
218241 int64_t estimatedNumItems_ = kMissingArgument ;
219242 int64_t numBits_ = kMissingArgument ;
243+ int32_t capacity_ = kMissingArgument ;
220244};
221245
222246} // namespace
@@ -226,6 +250,17 @@ bool registerBloomFilterAggAggregate(const std::string& name) {
226250 exec::AggregateFunctionSignatureBuilder ()
227251 .argumentType (" bigint" )
228252 .argumentType (" bigint" )
253+ .argumentType (" bigint" )
254+ .intermediateType (" varbinary" )
255+ .returnType (" varbinary" )
256+ .build (),
257+ exec::AggregateFunctionSignatureBuilder ()
258+ .argumentType (" bigint" )
259+ .argumentType (" bigint" )
260+ .intermediateType (" varbinary" )
261+ .returnType (" varbinary" )
262+ .build (),
263+ exec::AggregateFunctionSignatureBuilder ()
229264 .argumentType (" bigint" )
230265 .intermediateType (" varbinary" )
231266 .returnType (" varbinary" )
0 commit comments