Skip to content

Commit

Permalink
add topn query guardrails
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuglu-netflix committed Nov 1, 2024
1 parent d5bb7de commit dc7d238
Show file tree
Hide file tree
Showing 20 changed files with 377 additions and 13 deletions.
7 changes: 4 additions & 3 deletions docs/configuration/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2147,9 +2147,10 @@ context). If query does have `maxQueuedBytes` in the context, then that value is

### TopN query config

|Property|Description|Default|
|--------|-----------|-------|
|`druid.query.topN.minTopNThreshold`|See [TopN Aliasing](../querying/topnquery.md#aliasing) for details.|1000|
|Property| Description | Default |
|--------|-------------------------------------------------------------------------------|---------|
|`druid.query.topN.minTopNThreshold`| See [TopN Aliasing](../querying/topnquery.md#aliasing) for details. | 1000 |
|`druid.query.topN.maxTopNAggregatorHeapSizeBytes`| The maximum amount of aggregator heap bytes a given segment runner can acrue. | 10MB |

### Search query config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.topn.TopNQuery;
import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.query.topn.TopNQueryConfig;
import org.apache.druid.query.topn.TopNQueryEngine;
import org.apache.druid.query.topn.TopNResultValue;
import org.apache.druid.segment.IncrementalIndexSegment;
Expand Down Expand Up @@ -133,6 +134,7 @@ public void testTopNWithDistinctCountAgg() throws Exception
final Iterable<Result<TopNResultValue>> results =
engine.query(
query,
new TopNQueryConfig(),
new IncrementalIndexSegment(index, SegmentId.dummy(QueryRunnerTestHelper.DATA_SOURCE)),
null
).toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public class QueryContexts
public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner";
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
public static final String MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES = "maxTopNAggregatorHeapSizeBytes";
public static final String CATALOG_VALIDATION_ENABLED = "catalogValidationEnabled";

// projection context keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@
public abstract class BaseTopNAlgorithm<DimValSelector, DimValAggregateStore, Parameters extends TopNParams>
implements TopNAlgorithm<DimValSelector, Parameters>
{
public static Aggregator[] makeAggregators(Cursor cursor, List<AggregatorFactory> aggregatorSpecs)

public static Aggregator[] makeAggregators(TopNQuery query, Cursor cursor)
{
query.getAggregatorHelper().addAggregatorMemory();
final List<AggregatorFactory> aggregatorSpecs = query.getAggregatorSpecs();
Aggregator[] aggregators = new Aggregator[aggregatorSpecs.size()];
int aggregatorIndex = 0;
for (AggregatorFactory spec : aggregatorSpecs) {
Expand All @@ -52,8 +55,10 @@ public static Aggregator[] makeAggregators(Cursor cursor, List<AggregatorFactory
return aggregators;
}

protected static BufferAggregator[] makeBufferAggregators(Cursor cursor, List<AggregatorFactory> aggregatorSpecs)
protected static BufferAggregator[] makeBufferAggregators(TopNQuery query, Cursor cursor)
{
query.getAggregatorHelper().addAggregatorMemory();
final List<AggregatorFactory> aggregatorSpecs = query.getAggregatorSpecs();
BufferAggregator[] aggregators = new BufferAggregator[aggregatorSpecs.size()];
int aggregatorIndex = 0;
for (AggregatorFactory spec : aggregatorSpecs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ public int[] build()
resultsBuf.clear();

final int numBytesToWorkWith = resultsBuf.remaining();

query.getAggregatorHelper().addAggregatorMemory();
final int[] aggregatorSizes = new int[query.getAggregatorSpecs().size()];
int numBytesPerRecord = 0;

Expand Down Expand Up @@ -329,7 +331,7 @@ protected int[] updateDimValSelector(int[] dimValSelector, int numProcessed, int
@Override
protected BufferAggregator[] makeDimValAggregateStore(PooledTopNParams params)
{
return makeBufferAggregators(params.getCursor(), query.getAggregatorSpecs());
return makeBufferAggregators(query, params.getCursor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected long scanAndAggregate(

Aggregator[] theAggregators = aggregatesStore.computeIfAbsent(
key,
k -> makeAggregators(cursor, query.getAggregatorSpecs())
k -> makeAggregators(query, cursor)
);

for (Aggregator aggregator : theAggregators) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.query.topn;

import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.ResourceLimitExceededException;

import java.util.concurrent.atomic.AtomicLong;

public class TopNAggregatorResourceHelper
{
public static class Config {
public final long maxAggregatorHeapSize;
public Config(final long maxAggregatorHeapSize) {
this.maxAggregatorHeapSize = maxAggregatorHeapSize;
}
}

private final Config config;
private final long newAggregatorEstimatedMemorySize;
private final AtomicLong used = new AtomicLong(0);

TopNAggregatorResourceHelper(final long newAggregatorEstimatedMemorySize, final Config config) {
this.newAggregatorEstimatedMemorySize = newAggregatorEstimatedMemorySize;
this.config = config;
}

public void addAggregatorMemory() {
final long newTotal = used.addAndGet(newAggregatorEstimatedMemorySize);
if (newTotal > config.maxAggregatorHeapSize){
throw new ResourceLimitExceededException(StringUtils.format("Query ran out of memory. Maximum allowed bytes=[%d], Hit bytes=[%d]", config.maxAggregatorHeapSize, newTotal));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.druid.query.PerSegmentQueryOptimizationContext;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class TopNQuery extends BaseQuery<Result<TopNResultValue>>
private final DimFilter dimFilter;
private final List<AggregatorFactory> aggregatorSpecs;
private final List<PostAggregator> postAggregatorSpecs;
private TopNAggregatorResourceHelper aggregatorHelper;

@JsonCreator
public TopNQuery(
Expand Down Expand Up @@ -97,9 +99,18 @@ public TopNQuery(
: postAggregatorSpecs
);


final long expectedAllocBytes = aggregatorSpecs.stream().mapToLong(AggregatorFactory::getMaxIntermediateSizeWithNulls).sum();

Check warning

Code scanning / CodeQL

Dereferenced variable may be null Warning

Variable
aggregatorSpecs
may be null at this access as suggested by
this
null guard.
final long maxAggregatorHeapSizeBytes = this.context().getLong(QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES, TopNQueryConfig.DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES);
this.aggregatorHelper = new TopNAggregatorResourceHelper(expectedAllocBytes, new TopNAggregatorResourceHelper.Config(maxAggregatorHeapSizeBytes));

topNMetricSpec.verifyPreconditions(this.aggregatorSpecs, this.postAggregatorSpecs);
}

public TopNAggregatorResourceHelper getAggregatorHelper() {
return aggregatorHelper;
}

@Override
public boolean hasFilters()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
public class TopNQueryConfig
{
public static final int DEFAULT_MIN_TOPN_THRESHOLD = 1000;
public static final long DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES = 10 * (2 << 20); // 10mb

@JsonProperty
@Min(1)
Expand All @@ -37,4 +38,13 @@ public int getMinTopNThreshold()
{
return minTopNThreshold;
}

@JsonProperty
@Min(0)
private long maxTopNAggregatorHeapSizeBytes = DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES;

public long getMaxTopNAggregatorHeapSizeBytes()
{
return maxTopNAggregatorHeapSizeBytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.collections.NonBlockingPool;
import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.ColumnSelectorPlus;
import org.apache.druid.query.CursorGranularizer;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory;
Expand Down Expand Up @@ -75,6 +77,7 @@ public TopNQueryEngine(NonBlockingPool<ByteBuffer> bufferPool)
*/
public Sequence<Result<TopNResultValue>> query(
TopNQuery query,
TopNQueryConfig config,
final Segment segment,
@Nullable final TopNQueryMetrics queryMetrics
)
Expand All @@ -86,6 +89,10 @@ public Sequence<Result<TopNResultValue>> query(
);
}

if (!query.context().containsKey(QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES)){
query = query.withOverriddenContext(ImmutableMap.of(QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES, config.getMaxTopNAggregatorHeapSizeBytes()));
}

final CursorBuildSpec buildSpec = makeCursorBuildSpec(query, queryMetrics);
final CursorHolder cursorHolder = cursorFactory.makeCursorHolder(buildSpec);
if (cursorHolder.isPreAggregated()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,11 @@ public Sequence<Object[]> resultsAsArrays(TopNQuery query, Sequence<Result<TopNR
);
}

public TopNQueryConfig getConfig()
{
return this.config;
}

/**
* This returns a single frame containing the rows of the topN query's results
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public Sequence<Result<TopNResultValue>> run(
TopNQuery query = (TopNQuery) input.getQuery();
return queryEngine.query(
query,
toolchest.getConfig(),
segment,
(TopNQueryMetrics) input.getQueryMetrics()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Aggregator[] getValueAggregators(
long key = Double.doubleToLongBits(selector.getDouble());
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Aggregator[] getValueAggregators(
int key = Float.floatToIntBits(selector.getFloat());
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Aggregator[] getValueAggregators(TopNQuery query, BaseLongColumnValueSelector se
long key = selector.getLong();
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public long scanAndAggregate(
while (!cursor.isDone()) {
if (hasNulls && selector.isNull()) {
if (nullValueAggregates == null) {
nullValueAggregates = BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs());
nullValueAggregates = BaseTopNAlgorithm.makeAggregators(query, cursor);
}
for (Aggregator aggregator : nullValueAggregates) {
aggregator.aggregate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private long scanAndAggregateWithCardinalityKnown(
final Object key = dimensionValueConverter.apply(selector.lookupName(dimIndex));
aggs = aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
rowSelector[dimIndex] = aggs;
}
Expand Down Expand Up @@ -199,7 +199,7 @@ private long scanAndAggregateWithCardinalityUnknown(
final Object key = dimensionValueConverter.apply(selector.lookupName(dimIndex));
Aggregator[] aggs = aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
for (Aggregator aggregator : aggs) {
aggregator.aggregate();
Expand Down
Loading

0 comments on commit dc7d238

Please sign in to comment.