Skip to content

ES|QL random sampling #125570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d0cb643
Add a random sample command
bpintea Mar 3, 2025
ab4934c
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 3, 2025
e03dc15
Add non-operator-related tests
bpintea Mar 5, 2025
450f450
Make seed parameter optional. Various fixes
bpintea Mar 5, 2025
7a741a4
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 5, 2025
4c4fc1f
Make CsvTests more node-count-induced variation tolerant
bpintea Mar 6, 2025
1c56fb7
Simplify RandomSampleOperator
jan-elastic Mar 25, 2025
99a4a70
correct aggregations for random sampling
jan-elastic Mar 18, 2025
1598b80
Update docs/changelog/125570.yaml
jan-elastic Mar 25, 2025
6acf60b
don't correct multiple stats
jan-elastic Mar 26, 2025
23b5a3c
Refactor sample correction
jan-elastic Mar 26, 2025
b5cde58
Refactor sample correction once more
jan-elastic Mar 26, 2025
ec51f8c
spotless
jan-elastic Mar 26, 2025
67a0c44
fix random sample csv tests
jan-elastic Mar 26, 2025
878f09c
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 27, 2025
4bfe076
Make `isSampledCorrected` field final
jan-elastic Mar 27, 2025
aedc2f4
fix StatementParserTests
jan-elastic Mar 27, 2025
f40efb3
make sample corrected constructor private
jan-elastic Mar 27, 2025
473217d
push down through where and sort
jan-elastic Mar 27, 2025
92949fa
rename RANDOM_SAMPLE -> SAMPLE
jan-elastic Mar 27, 2025
e449ae4
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 27, 2025
83c0e07
SampleOperatorTests + fix status
jan-elastic Mar 28, 2025
c98391a
Test accuracy of sampling operator
jan-elastic Mar 31, 2025
0e5d497
polish code
jan-elastic Mar 31, 2025
2e32a72
error on seed in sampling operator
jan-elastic Apr 1, 2025
cad45d7
Don't push sample correction through limit
jan-elastic Apr 8, 2025
01cc677
Don't push sample correction through mv_expand
jan-elastic Apr 8, 2025
011e612
CSV tests
jan-elastic Apr 3, 2025
27f4294
propagate multiple sample probabilities
jan-elastic Apr 9, 2025
875c2d5
REST test
jan-elastic Apr 9, 2025
eb1f728
enable all csv tests
jan-elastic Apr 9, 2025
ccc7179
Fix CSV test with sample+limit
jan-elastic Apr 10, 2025
a5ef3bd
add SampleBreaking interface
jan-elastic Apr 10, 2025
ac41e9c
comments
jan-elastic Apr 15, 2025
12df3c2
linkedlist -> arraydeque for efficiency
jan-elastic Apr 15, 2025
6051ec0
use samplebreaking in pushdown
jan-elastic Apr 15, 2025
c7cbf7e
different operator categories wrt sampling. Remove SampleBreaking int…
jan-elastic Apr 15, 2025
0a5085b
sample metrics
jan-elastic Apr 23, 2025
ba917cd
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 23, 2025
29fbc39
fix esql metrics test
jan-elastic Apr 23, 2025
af95b37
delete unused file
jan-elastic Apr 23, 2025
f13a5c9
Merge branch 'main' into feat/random_sample
jan-elastic Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125570.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125570
summary: ES|QL random sampling
area: Machine Learning
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SYNONYMS_REFRESH_PARAM = def(9_060_0_00);
public static final TransportVersion DOC_FIELDS_AS_LIST = def(9_061_0_00);
public static final TransportVersion DENSE_VECTOR_OFF_HEAP_STATS = def(9_062_0_00);
// Returned by RandomSamplingQueryBuilder#getMinimalSupportedVersion().
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_063_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
Expand Down Expand Up @@ -1186,6 +1187,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
registerQuery(
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
);

registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,34 @@ public final class RandomSamplingQuery extends Query {
* can be generated
*/
public RandomSamplingQuery(double p, int seed, int hash) {
    // Range validation is delegated to the shared static check so that there is a single
    // implementation (and a single error message) for an out-of-range probability.
    checkProbabilityRange(p);
    this.p = p;
    this.seed = seed;
    this.hash = hash;
}

/**
 * Verifies that the probability is within the (0.0, 1.0) range, both bounds excluded.
 * {@code NaN} is rejected as well.
 *
 * @param p the sampling probability to validate
 * @throws IllegalArgumentException in case of an invalid probability.
 */
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
    // Expressed as a negated "in range" test: every comparison involving NaN is false, so a
    // check of the form (p <= 0.0 || p >= 1.0) would let NaN through.
    if ((p > 0.0 && p < 1.0) == false) {
        throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
    }
}

/** Returns the value held in the {@code p} field of this query. */
public double probability() {
    return this.p;
}

/** Returns the value held in the {@code seed} field of this query. */
public int seed() {
    return this.seed;
}

/** Returns the value held in the {@code hash} field of this query. */
public int hash() {
    return this.hash;
}

@Override
public String toString(String field) {
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
Expand Down Expand Up @@ -98,13 +118,13 @@ public void visit(QueryVisitor visitor) {
/**
* A DocIDSetIter that skips a geometrically random number of documents
*/
static class RandomSamplingIterator extends DocIdSetIterator {
public static class RandomSamplingIterator extends DocIdSetIterator {
private final int maxDoc;
private final double p;
private final FastGeometric distribution;
private int doc = -1;

RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
this.maxDoc = maxDoc;
this.p = p;
this.distribution = new FastGeometric(rng, p);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * Query builder producing a {@link RandomSamplingQuery} from a probability, a seed and a hash.
 * <p>
 * The probability is mandatory and fixed at construction time; the seed defaults to a random
 * integer and the hash to {@code 0}, and both can be overridden through the fluent setters.
 */
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {

    public static final String NAME = "random_sampling";
    // NOTE(review): the probability is read from / written to XContent under the key "query",
    // not "probability" — confirm this is intended before relying on the format.
    static final ParseField PROBABILITY = new ParseField("query");
    static final ParseField SEED = new ParseField("seed");
    static final ParseField HASH = new ParseField("hash");

    private final double probability;
    private int seed = Randomness.get().nextInt();
    private int hash = 0;

    /**
     * Creates a builder for the given sampling probability.
     *
     * @param probability the sampling probability, validated by {@code checkProbabilityRange}
     * @throws IllegalArgumentException if the probability fails that validation
     */
    public RandomSamplingQueryBuilder(double probability) {
        checkProbabilityRange(probability);
        this.probability = probability;
    }

    /**
     * Overrides the randomly chosen default seed.
     *
     * @param seed the seed to hand to the resulting query
     * @return this builder
     */
    public RandomSamplingQueryBuilder seed(int seed) {
        // The probability is final and was already validated by the public constructor,
        // so there is nothing to re-check when only the seed changes.
        this.seed = seed;
        return this;
    }

    /**
     * Deserialization constructor: reads probability, seed and hash in the order
     * {@link #doWriteTo(StreamOutput)} writes them. No range validation is applied here.
     */
    public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.probability = in.readDouble();
        this.seed = in.readInt();
        this.hash = in.readInt();
    }

    /**
     * Overrides the default hash of {@code 0}.
     *
     * @param hash the hash to hand to the resulting query; must not be {@code null}, as it is unboxed on assignment
     * @return this builder
     */
    public RandomSamplingQueryBuilder hash(Integer hash) {
        this.hash = hash;
        return this;
    }

    /** Returns the sampling probability. */
    public double probability() {
        return probability;
    }

    /** Returns the seed. */
    public int seed() {
        return seed;
    }

    /** Returns the hash. */
    public int hash() {
        return hash;
    }

    /** Writes probability, seed and hash, in that order. */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeDouble(probability);
        out.writeInt(seed);
        out.writeInt(hash);
    }

    /** Renders an object named {@link #NAME} holding the probability, seed and hash fields. */
    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(PROBABILITY.getPreferredName(), probability);
        builder.field(SEED.getPreferredName(), seed);
        builder.field(HASH.getPreferredName(), hash);
        builder.endObject();
    }

    // args[0] is the mandatory probability; args[1] (seed) and args[2] (hash) are optional and
    // left at the builder defaults when absent.
    private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
            if (args[1] != null) {
                randomSamplingQueryBuilder.seed((int) args[1]);
            }
            if (args[2] != null) {
                randomSamplingQueryBuilder.hash((int) args[2]);
            }
            return randomSamplingQueryBuilder;
        }
    );

    static {
        PARSER.declareDouble(constructorArg(), PROBABILITY);
        PARSER.declareInt(optionalConstructorArg(), SEED);
        PARSER.declareInt(optionalConstructorArg(), HASH);
    }

    /** Parses a builder from XContent using the declared probability, seed and hash fields. */
    public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
        return PARSER.apply(parser, null);
    }

    /** Creates the Lucene-level {@link RandomSamplingQuery} from the current probability, seed and hash. */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        return new RandomSamplingQuery(probability, seed, hash);
    }

    @Override
    protected boolean doEquals(RandomSamplingQueryBuilder other) {
        // Double.compare keeps equality consistent with doHashCode(), which hashes the boxed
        // double: unlike ==, it distinguishes 0.0 from -0.0 and treats NaN as equal to itself.
        return Double.compare(probability, other.probability) == 0 && seed == other.seed && hash == other.hash;
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(probability, seed, hash);
    }

    /**
     * Returns the name of the writeable object
     */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * The minimal version of the recipient this object can be sent to
     */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ public CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> getReque
"range",
"regexp",
"knn_score_doc",
"random_sampling",
"script",
"script_score",
"simple_query_string",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.Query;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.test.AbstractQueryTestCase;
import org.elasticsearch.xcontent.XContentParseException;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

/**
 * Tests for {@link RandomSamplingQueryBuilder}: random builder generation, conversion to a
 * {@link RandomSamplingQuery}, and rejection of unknown fields while parsing.
 */
public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {

    /** Builds a query builder with a random probability and, at random, an explicit seed and/or hash. */
    @Override
    protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
        var queryBuilder = new RandomSamplingQueryBuilder(randomDoubleBetween(0.0, 1.0, false));
        if (randomBoolean()) {
            queryBuilder.seed(randomInt());
        }
        if (randomBoolean()) {
            queryBuilder.hash(randomInt());
        }
        return queryBuilder;
    }

    /** Checks that the Lucene query carries the builder's probability, seed and hash. */
    @Override
    protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
        throws IOException {
        var samplingQuery = asInstanceOf(RandomSamplingQuery.class, query);
        assertThat(samplingQuery.probability(), equalTo(queryBuilder.probability()));
        assertThat(samplingQuery.seed(), equalTo(queryBuilder.seed()));
        assertThat(samplingQuery.hash(), equalTo(queryBuilder.hash()));
    }

    @Override
    protected boolean supportsBoost() {
        return false;
    }

    @Override
    protected boolean supportsQueryName() {
        return false;
    }

    /** Parsing a query that contains an undeclared field must fail and name that field. */
    @Override
    public void testUnknownField() {
        var json = new StringBuilder("{ \"").append(RandomSamplingQueryBuilder.NAME)
            .append("\" : {\"bogusField\" : \"someValue\", \"")
            .append(RandomSamplingQueryBuilder.PROBABILITY.getPreferredName())
            .append("\" : \"")
            .append(randomBoolean())
            .append("\", \"")
            .append(RandomSamplingQueryBuilder.SEED.getPreferredName())
            .append("\" : \"")
            .append(randomInt())
            .append("\", \"")
            .append(RandomSamplingQueryBuilder.HASH.getPreferredName())
            .append("\" : \"")
            .append(randomInt())
            .append("\" } }")
            .toString();
        var failure = expectThrows(XContentParseException.class, () -> parseQuery(json));
        assertTrue(failure.getMessage().contains("bogusField"));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ public static boolean isSupported(String name) {
return ATTRIBUTES_MAP.containsKey(name);
}

/**
 * Tells whether the expression is a {@link MetadataAttribute} whose name equals {@code SCORE}.
 *
 * @param a the expression to inspect
 * @return {@code true} only for a metadata attribute named {@code SCORE}
 */
public static boolean isScoreAttribute(Expression a) {
    if (a instanceof MetadataAttribute metadataAttribute) {
        return metadataAttribute.name().equals(SCORE);
    }
    return false;
}

@Override
@SuppressWarnings("checkstyle:EqualsHashCode")// equals is implemented in parent. See innerEquals instead
public int hashCode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,21 @@ public Page projectBlocks(int[] blockMapping) {
}
}
}

/**
 * Builds a new page whose blocks are the result of calling {@code filter(positions)} on each
 * block of this page.
 * <p>
 * {@code releaseBlocks()} is invoked on this page whether or not filtering succeeds; if it
 * fails part-way, the blocks filtered so far are closed before the exception propagates.
 *
 * @param positions the positions handed to each block's {@code filter}
 * @return a new page wrapping the filtered blocks
 */
public Page filter(int... positions) {
    final int blockCount = blocks.length;
    Block[] kept = new Block[blockCount];
    boolean completed = false;
    try {
        for (int b = 0; b < blockCount; b++) {
            kept[b] = getBlock(b).filter(positions);
        }
        completed = true;
    } finally {
        // Order matters: this page's blocks are released first, then any partial result is closed.
        releaseBlocks();
        if (completed == false) {
            Releasables.closeExpectNoException(kept);
        }
    }
    return new Page(kept);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointDetector;
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;

/**
Expand Down Expand Up @@ -68,8 +68,8 @@ public ChangePointOperator(DriverContext driverContext, int channel, String sour
this.sourceColumn = sourceColumn;

finished = false;
inputPages = new LinkedList<>();
outputPages = new LinkedList<>();
inputPages = new ArrayDeque<>();
outputPages = new ArrayDeque<>();
warnings = null;
}

Expand Down
Loading
Loading