 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.InternalOrder;
 import org.opensearch.search.aggregations.LeafBucketCollector;
+import org.opensearch.search.aggregations.bucket.BucketsAggregator;
 import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
 import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds;
 import org.opensearch.search.aggregations.support.AggregationPath;
@@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() {

     @Override
     protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
-        MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
+        MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub);
         return new LeafBucketCollector() {
             @Override
             public void collect(int doc, long owningBucketOrd) throws IOException {
-                for (BytesRef compositeKey : collector.apply(doc)) {
-                    long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
-                    if (bucketOrd < 0) {
-                        bucketOrd = -1 - bucketOrd;
-                        collectExistingBucket(sub, doc, bucketOrd);
-                    } else {
-                        collectBucket(sub, doc, bucketOrd);
-                    }
-                }
+                collector.apply(doc, owningBucketOrd);
             }
         };
     }
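Context for this hunk (not part of the diff): the removed branch decodes the return convention of BytesKeyedBucketOrds.add, which hands back a non-negative ordinal the first time a composite key is seen and -1 - existingOrdinal afterwards; the new MultiValuesSourceCollectorImpl added further down decodes it the same way. A minimal, map-backed sketch of that convention, purely illustrative and not the OpenSearch class:

// Illustration only: a map-backed stand-in for the ordinal convention decoded above.
import java.util.HashMap;
import java.util.Map;

class OrdinalConventionSketch {
    private final Map<String, Long> ords = new HashMap<>();

    long add(String key) {
        Long existing = ords.get(key);
        if (existing != null) {
            return -1 - existing; // key already present: encoded as a negative value
        }
        long ord = ords.size();
        ords.put(key, ord);
        return ord; // first occurrence: plain non-negative ordinal
    }

    public static void main(String[] args) {
        OrdinalConventionSketch ords = new OrdinalConventionSketch();
        long first = ords.add("a|1");  // 0  -> collectBucket path
        long again = ords.add("a|1");  // -1 -> decode with -1 - again == 0, collectExistingBucket path
        System.out.println(first + " " + again + " -> existing ord " + (-1 - again));
    }
}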
@@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
         }
         // we need to fill-in the blanks
         for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) {
-            MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
+            MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null);
             // brute force
             for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
-                for (BytesRef compositeKey : collector.apply(docId)) {
-                    bucketOrds.add(owningBucketOrd, compositeKey);
-                }
+                collector.apply(docId, owningBucketOrd);
             }
         }
     }
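Note on the null arguments above: as the new collector further down shows, collectViaAggregator is only true when both aggregator and sub are non-null, so this brute-force pass merely registers composite keys in bucketOrds without forwarding documents to sub-aggregations, matching what the removed inner loop did.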
@@ -287,7 +278,8 @@ interface MultiTermsValuesSourceCollector {
          * Collect a list values of multi_terms on each doc.
          * Each terms could have multi_values, so the result is the cartesian product of each term's values.
          */
-        List<BytesRef> apply(int doc) throws IOException;
+        void apply(int doc, long owningBucketOrd) throws IOException;
+
     }
 
     @FunctionalInterface
@@ -361,51 +353,17 @@ public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
             this.valuesSources = valuesSources;
         }
 
-        public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException {
+        public MultiTermsValuesSourceCollector getValues(
+            LeafReaderContext ctx,
+            BytesKeyedBucketOrds bucketOrds,
+            BucketsAggregator aggregator,
+            LeafBucketCollector sub
+        ) throws IOException {
             List<InternalValuesSourceCollector> collectors = new ArrayList<>();
             for (InternalValuesSource valuesSource : valuesSources) {
                 collectors.add(valuesSource.apply(ctx));
             }
-            return new MultiTermsValuesSourceCollector() {
-                @Override
-                public List<BytesRef> apply(int doc) throws IOException {
-                    List<List<TermValue<?>>> collectedValues = new ArrayList<>();
-                    for (InternalValuesSourceCollector collector : collectors) {
-                        collectedValues.add(collector.apply(doc));
-                    }
-                    List<BytesRef> result = new ArrayList<>();
-                    scratch.seek(0);
-                    scratch.writeVInt(collectors.size()); // number of fields per composite key
-                    cartesianProduct(result, scratch, collectedValues, 0);
-                    return result;
-                }
-
-                /**
-                 * Cartesian product using depth first search.
-                 *
-                 * <p>
-                 * Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
-                 * but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
-                 */
-                private void cartesianProduct(
-                    List<BytesRef> compositeKeys,
-                    BytesStreamOutput scratch,
-                    List<List<TermValue<?>>> collectedValues,
-                    int index
-                ) throws IOException {
-                    if (collectedValues.size() == index) {
-                        compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
-                        return;
-                    }
-
-                    long position = scratch.position();
-                    for (TermValue<?> value : collectedValues.get(index)) {
-                        value.writeTo(scratch); // encode the value
-                        cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
-                        scratch.seek(position); // backtrack
-                    }
-                }
-            };
+            return new MultiValuesSourceCollectorImpl(collectors, scratch, bucketOrds, aggregator, sub);
         }
 
         @Override
@@ -414,6 +372,74 @@ public void close() {
         }
     }
 
+    static class MultiValuesSourceCollectorImpl implements MultiTermsValuesSourceCollector {
+
+        private final List<InternalValuesSourceCollector> collectors;
+        private final BytesStreamOutput scratch;
+        private final BytesKeyedBucketOrds bucketOrds;
+        private final BucketsAggregator aggregator;
+        private final LeafBucketCollector sub;
+
+        private final boolean collectViaAggregator;
+
+        public MultiValuesSourceCollectorImpl(
+            List<InternalValuesSourceCollector> collectors,
+            BytesStreamOutput scratch,
+            BytesKeyedBucketOrds bucketOrds,
+            BucketsAggregator aggregator,
+            LeafBucketCollector sub
+        ) {
+            this.collectors = collectors;
+            this.scratch = scratch;
+            this.bucketOrds = bucketOrds;
+            this.aggregator = aggregator;
+            this.sub = sub;
+            this.collectViaAggregator = aggregator != null && sub != null;
+        }
+
+        @Override
+        public void apply(int doc, long owningBucketOrd) throws IOException {
+            List<List<TermValue<?>>> collectedValues = new ArrayList<>();
+            for (InternalValuesSourceCollector collector : collectors) {
+                collectedValues.add(collector.apply(doc));
+            }
+            scratch.seek(0);
+            scratch.writeVInt(collectors.size()); // number of fields per composite key
+            cartesianProductRecursive(collectedValues, 0, owningBucketOrd, doc);
+        }
+
+        /**
+         * Cartesian product using depth first search.
+         */
+        private void cartesianProductRecursive(List<List<TermValue<?>>> collectedValues, int index, long owningBucketOrd, int doc)
+            throws IOException {
+            if (collectedValues.size() == index) {
+                // Avoid performing a deep copy of the composite key
+                long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef());
+                if (collectViaAggregator) {
+                    if (bucketOrd < 0) {
+                        bucketOrd = -1 - bucketOrd;
+                        aggregator.collectExistingBucket(sub, doc, bucketOrd);
+                    } else {
+                        aggregator.collectBucket(sub, doc, bucketOrd);
+                    }
+                }
+                return;
+            }
+
+            long position = scratch.position();
+            List<TermValue<?>> values = collectedValues.get(index);
+            int numIterations = values.size();
+            for (int i = 0; i < numIterations; i++) {
+                TermValue<?> value = values.get(i);
+                value.writeTo(scratch); // encode the value
+                cartesianProductRecursive(collectedValues, index + 1, owningBucketOrd, doc); // dfs
+                scratch.seek(position); // backtrack
+            }
+        }
+
+    }
+
     /**
      * Factory for construct {@link InternalValuesSource}.
      *
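For readers following the new cartesianProductRecursive method, here is a self-contained sketch of the same encode-and-backtrack idea, with a StringBuilder standing in for the reusable BytesStreamOutput scratch buffer and '|' as a made-up delimiter (the real code writes TermValue encodings, not strings):

// Standalone sketch of the backtracking cartesian product used above.
// Shared prefixes are written once and rewound (setLength plays the role of scratch.seek),
// so each composite key is only materialized at the leaves of the DFS.
import java.util.ArrayList;
import java.util.List;

class CartesianProductSketch {
    static void dfs(List<List<String>> perFieldValues, int index, StringBuilder scratch, List<String> out) {
        if (index == perFieldValues.size()) {
            out.add(scratch.toString());       // leaf: one composite key per path
            return;
        }
        int position = scratch.length();       // remember the shared prefix length
        for (String value : perFieldValues.get(index)) {
            scratch.append(value).append('|'); // encode this field's value
            dfs(perFieldValues, index + 1, scratch, out);
            scratch.setLength(position);       // backtrack to the shared prefix
        }
    }

    public static void main(String[] args) {
        List<List<String>> values = List.of(List.of("a", "b"), List.of("1", "2", "3"));
        List<String> keys = new ArrayList<>();
        dfs(values, 0, new StringBuilder(), keys);
        System.out.println(keys); // [a|1|, a|2|, a|3|, b|1|, b|2|, b|3|]
    }
}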
@@ -441,9 +467,13 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
                    if (i > 0 && bytes.equals(previous)) {
                        continue;
                    }
-                    BytesRef copy = BytesRef.deepCopyOf(bytes);
-                    termValues.add(TermValue.of(copy));
-                    previous = copy;
+                    if (valuesCount > 1) {
+                        BytesRef copy = BytesRef.deepCopyOf(bytes);
+                        termValues.add(TermValue.of(copy));
+                        previous = copy;
+                    } else {
+                        termValues.add(TermValue.of(bytes));
+                    }
                }
                return termValues;
            };
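This last hunk skips BytesRef.deepCopyOf when a document has exactly one value, presumably safe because that single ref is serialized into the scratch buffer before the doc-values iterator moves on; with multiple values per document the copy remains necessary because such iterators may reuse one BytesRef instance across nextValue() calls. A small sketch of that reuse hazard, where FakeValues is a hypothetical stand-in for a doc-values iterator:

// Sketch of why the multi-value branch still deep-copies: iterators are allowed to
// reuse the same BytesRef across nextValue() calls, so holding on to the ref while
// continuing to iterate would silently change its content.
import org.apache.lucene.util.BytesRef;

class ReusedRefSketch {
    static class FakeValues {
        private final byte[] buffer = new byte[1];
        private final BytesRef reused = new BytesRef(buffer);
        private byte next = 'a';

        BytesRef nextValue() {
            buffer[0] = next++;   // overwrite the shared backing array
            return reused;        // same instance every call
        }
    }

    public static void main(String[] args) {
        FakeValues values = new FakeValues();
        BytesRef kept = values.nextValue();          // points at 'a'
        BytesRef copy = BytesRef.deepCopyOf(kept);   // safe snapshot
        values.nextValue();                          // now kept points at 'b'
        System.out.println(kept.utf8ToString() + " vs " + copy.utf8ToString()); // b vs a
    }
}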