Skip to content

Commit fb84412

Browse files
committed
Avoid deep copy and other allocation improvements
Signed-off-by: expani <anijainc@amazon.com>
1 parent 59302a3 commit fb84412

File tree

1 file changed

+89
-59
lines changed

1 file changed

+89
-59
lines changed

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java

Lines changed: 89 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.opensearch.search.aggregations.InternalAggregation;
3434
import org.opensearch.search.aggregations.InternalOrder;
3535
import org.opensearch.search.aggregations.LeafBucketCollector;
36+
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
3637
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
3738
import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds;
3839
import org.opensearch.search.aggregations.support.AggregationPath;
@@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() {
215216

216217
@Override
217218
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
218-
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
219+
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub);
219220
return new LeafBucketCollector() {
220221
@Override
221222
public void collect(int doc, long owningBucketOrd) throws IOException {
222-
for (BytesRef compositeKey : collector.apply(doc)) {
223-
long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
224-
if (bucketOrd < 0) {
225-
bucketOrd = -1 - bucketOrd;
226-
collectExistingBucket(sub, doc, bucketOrd);
227-
} else {
228-
collectBucket(sub, doc, bucketOrd);
229-
}
230-
}
223+
collector.apply(doc, owningBucketOrd);
231224
}
232225
};
233226
}
@@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
268261
}
269262
// we need to fill-in the blanks
270263
for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) {
271-
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
264+
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null);
272265
// brute force
273266
for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
274-
for (BytesRef compositeKey : collector.apply(docId)) {
275-
bucketOrds.add(owningBucketOrd, compositeKey);
276-
}
267+
collector.apply(docId, owningBucketOrd);
277268
}
278269
}
279270
}
@@ -287,7 +278,8 @@ interface MultiTermsValuesSourceCollector {
287278
* Collect a list values of multi_terms on each doc.
288279
* Each terms could have multi_values, so the result is the cartesian product of each term's values.
289280
*/
290-
List<BytesRef> apply(int doc) throws IOException;
281+
void apply(int doc, long owningBucketOrd) throws IOException;
282+
291283
}
292284

293285
@FunctionalInterface
@@ -361,51 +353,17 @@ public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
361353
this.valuesSources = valuesSources;
362354
}
363355

364-
public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException {
356+
public MultiTermsValuesSourceCollector getValues(
357+
LeafReaderContext ctx,
358+
BytesKeyedBucketOrds bucketOrds,
359+
BucketsAggregator aggregator,
360+
LeafBucketCollector sub
361+
) throws IOException {
365362
List<InternalValuesSourceCollector> collectors = new ArrayList<>();
366363
for (InternalValuesSource valuesSource : valuesSources) {
367364
collectors.add(valuesSource.apply(ctx));
368365
}
369-
return new MultiTermsValuesSourceCollector() {
370-
@Override
371-
public List<BytesRef> apply(int doc) throws IOException {
372-
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
373-
for (InternalValuesSourceCollector collector : collectors) {
374-
collectedValues.add(collector.apply(doc));
375-
}
376-
List<BytesRef> result = new ArrayList<>();
377-
scratch.seek(0);
378-
scratch.writeVInt(collectors.size()); // number of fields per composite key
379-
cartesianProduct(result, scratch, collectedValues, 0);
380-
return result;
381-
}
382-
383-
/**
384-
* Cartesian product using depth first search.
385-
*
386-
* <p>
387-
* Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
388-
* but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
389-
*/
390-
private void cartesianProduct(
391-
List<BytesRef> compositeKeys,
392-
BytesStreamOutput scratch,
393-
List<List<TermValue<?>>> collectedValues,
394-
int index
395-
) throws IOException {
396-
if (collectedValues.size() == index) {
397-
compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
398-
return;
399-
}
400-
401-
long position = scratch.position();
402-
for (TermValue<?> value : collectedValues.get(index)) {
403-
value.writeTo(scratch); // encode the value
404-
cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
405-
scratch.seek(position); // backtrack
406-
}
407-
}
408-
};
366+
return new MultiValuesSourceCollectorImpl(collectors, scratch, bucketOrds, aggregator, sub);
409367
}
410368

411369
@Override
@@ -414,6 +372,74 @@ public void close() {
414372
}
415373
}
416374

375+
static class MultiValuesSourceCollectorImpl implements MultiTermsValuesSourceCollector {
376+
377+
private final List<InternalValuesSourceCollector> collectors;
378+
private final BytesStreamOutput scratch;
379+
private final BytesKeyedBucketOrds bucketOrds;
380+
private final BucketsAggregator aggregator;
381+
private final LeafBucketCollector sub;
382+
383+
private final boolean collectViaAggregator;
384+
385+
public MultiValuesSourceCollectorImpl(
386+
List<InternalValuesSourceCollector> collectors,
387+
BytesStreamOutput scratch,
388+
BytesKeyedBucketOrds bucketOrds,
389+
BucketsAggregator aggregator,
390+
LeafBucketCollector sub
391+
) {
392+
this.collectors = collectors;
393+
this.scratch = scratch;
394+
this.bucketOrds = bucketOrds;
395+
this.aggregator = aggregator;
396+
this.sub = sub;
397+
this.collectViaAggregator = aggregator != null && sub != null;
398+
}
399+
400+
@Override
401+
public void apply(int doc, long owningBucketOrd) throws IOException {
402+
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
403+
for (InternalValuesSourceCollector collector : collectors) {
404+
collectedValues.add(collector.apply(doc));
405+
}
406+
scratch.seek(0);
407+
scratch.writeVInt(collectors.size()); // number of fields per composite key
408+
cartesianProductRecursive(collectedValues, 0, owningBucketOrd, doc);
409+
}
410+
411+
/**
412+
* Cartesian product using depth first search.
413+
*/
414+
private void cartesianProductRecursive(List<List<TermValue<?>>> collectedValues, int index, long owningBucketOrd, int doc)
415+
throws IOException {
416+
if (collectedValues.size() == index) {
417+
// Avoid performing a deep copy of the composite key
418+
long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef());
419+
if (collectViaAggregator) {
420+
if (bucketOrd < 0) {
421+
bucketOrd = -1 - bucketOrd;
422+
aggregator.collectExistingBucket(sub, doc, bucketOrd);
423+
} else {
424+
aggregator.collectBucket(sub, doc, bucketOrd);
425+
}
426+
}
427+
return;
428+
}
429+
430+
long position = scratch.position();
431+
List<TermValue<?>> values = collectedValues.get(index);
432+
int numIterations = values.size();
433+
for (int i = 0; i < numIterations; i++) {
434+
TermValue<?> value = values.get(i);
435+
value.writeTo(scratch); // encode the value
436+
cartesianProductRecursive(collectedValues, index + 1, owningBucketOrd, doc); // dfs
437+
scratch.seek(position); // backtrack
438+
}
439+
}
440+
441+
}
442+
417443
/**
418444
* Factory for construct {@link InternalValuesSource}.
419445
*
@@ -441,9 +467,13 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
441467
if (i > 0 && bytes.equals(previous)) {
442468
continue;
443469
}
444-
BytesRef copy = BytesRef.deepCopyOf(bytes);
445-
termValues.add(TermValue.of(copy));
446-
previous = copy;
470+
if (valuesCount > 1) {
471+
BytesRef copy = BytesRef.deepCopyOf(bytes);
472+
termValues.add(TermValue.of(copy));
473+
previous = copy;
474+
} else {
475+
termValues.add(TermValue.of(bytes));
476+
}
447477
}
448478
return termValues;
449479
};

0 commit comments

Comments
 (0)