Skip to content

[8.19] Backport ES|QL sample processing command #129617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125570.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125570
summary: ES|QL random sampling
area: Machine Learning
type: feature
issues: []
2 changes: 2 additions & 0 deletions docs/reference/esql/esql-commands.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ endif::[]
* experimental:[] <<esql-lookup-join>>
* experimental:[] <<esql-mv_expand>>
* <<esql-rename>>
* experimental:[] <<esql-sample>>
* <<esql-sort>>
* <<esql-stats-by>>
* <<esql-where>>
Expand All @@ -70,6 +71,7 @@ include::processing-commands/limit.asciidoc[]
include::processing-commands/lookup.asciidoc[]
include::processing-commands/mv_expand.asciidoc[]
include::processing-commands/rename.asciidoc[]
include::processing-commands/sample.asciidoc[]
include::processing-commands/sort.asciidoc[]
include::processing-commands/stats.asciidoc[]
include::processing-commands/where.asciidoc[]
30 changes: 30 additions & 0 deletions docs/reference/esql/processing-commands/sample.asciidoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[discrete]
[[esql-sample]]
=== `SAMPLE`

preview::[]

The `SAMPLE` command samples a fraction of the table rows.

**Syntax**

[source,esql]
----
SAMPLE probability
----

*Parameters*

`probability`::
The probability that a row is included in the sample. The value must be between 0 and 1, exclusive.

*Example*

[source.merge.styled,esql]
----
include::{esql-specs}/sample.csv-spec[tag=sampleForDocs]
----
[%header.monospaced.styled,format=dsv,separator=|]
|===
include::{esql-specs}/sample.csv-spec[tag=sampleForDocs-result]
|===
3 changes: 2 additions & 1 deletion docs/reference/rest-api/usage.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ GET /_xpack/usage
"lookup_join" : 0,
"change_point" : 0,
"completion": 0,
"rerank": 0
"rerank": 0,
"sample": 0
},
"queries" : {
"rest" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_54);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
Expand Down Expand Up @@ -1209,6 +1210,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
registerQuery(
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
);

registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,34 @@ public final class RandomSamplingQuery extends Query {
* can be generated
*/
public RandomSamplingQuery(double p, int seed, int hash) {
    // Validate once, through the shared helper. It throws IllegalArgumentException for an
    // out-of-range probability, so an inline copy of the same comparison is redundant and
    // only risks the two error messages drifting apart.
    checkProbabilityRange(p);
    this.p = p;
    this.seed = seed;
    this.hash = hash;
}

/**
 * Verifies that the probability is within the (0.0, 1.0) range, both bounds excluded.
 * {@code NaN} is rejected as well, since it is not a usable probability.
 *
 * @param p the probability to validate
 * @throws IllegalArgumentException in case of an invalid probability.
 */
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
    // Expressed as a negated "is in range" test on purpose: every comparison involving NaN
    // evaluates to false, so NaN fails the in-range test and is rejected here, whereas an
    // "is out of range" test (p <= 0.0 || p >= 1.0) would silently let it through.
    if ((p > 0.0 && p < 1.0) == false) {
        throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
    }
}

/**
 * @return the sampling probability this query was constructed with
 */
public double probability() {
    return this.p;
}

/**
 * @return the seed value this query was constructed with
 */
public int seed() {
    return this.seed;
}

/**
 * @return the hash value this query was constructed with
 */
public int hash() {
    return this.hash;
}

@Override
public String toString(String field) {
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
Expand Down Expand Up @@ -97,13 +117,13 @@ public void visit(QueryVisitor visitor) {
/**
* A DocIdSetIterator that skips a geometrically distributed random number of documents
*/
static class RandomSamplingIterator extends DocIdSetIterator {
public static class RandomSamplingIterator extends DocIdSetIterator {
private final int maxDoc;
private final double p;
private final FastGeometric distribution;
private int doc = -1;

RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
this.maxDoc = maxDoc;
this.p = p;
this.distribution = new FastGeometric(rng, p);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * Builder for the {@code random_sampling} query. It carries a sampling probability, a seed and a
 * hash, can be read from and written to both the transport stream and XContent, and produces a
 * {@link RandomSamplingQuery} from those three values.
 */
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {

    public static final String NAME = "random_sampling";
    // NOTE(review): the probability is rendered and parsed under the field name "query", not
    // "probability". Left untouched because renaming it would change the XContent format;
    // confirm whether this name is intentional.
    static final ParseField PROBABILITY = new ParseField("query");
    static final ParseField SEED = new ParseField("seed");
    static final ParseField HASH = new ParseField("hash");

    private final double probability;
    // Defaults to a random seed; replaced by seed(int) or by the value read from the stream.
    private int seed = Randomness.get().nextInt();
    private int hash = 0;

    /**
     * Creates a builder with a random seed and a hash of {@code 0}.
     *
     * @param probability the sampling probability
     * @throws IllegalArgumentException if {@code probability} is rejected by
     *         {@link RandomSamplingQuery#checkProbabilityRange(double)}
     */
    public RandomSamplingQueryBuilder(double probability) {
        checkProbabilityRange(probability);
        this.probability = probability;
    }

    /**
     * Sets the seed.
     *
     * @param seed the seed to use instead of the randomly chosen default
     * @return this builder
     */
    public RandomSamplingQueryBuilder seed(int seed) {
        // The probability is final and was validated by the constructor, so it is not re-checked here.
        this.seed = seed;
        return this;
    }

    /**
     * Reads a builder from the stream: probability, seed and hash, in the order written by
     * {@link #doWriteTo(StreamOutput)}.
     */
    public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.probability = in.readDouble();
        this.seed = in.readInt();
        this.hash = in.readInt();
    }

    /**
     * Sets the hash.
     *
     * @param hash the hash value; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code hash} is {@code null}
     */
    public RandomSamplingQueryBuilder hash(Integer hash) {
        // Fail with an explicit message rather than an anonymous NPE from auto-unboxing.
        this.hash = Objects.requireNonNull(hash, "[hash] must not be null");
        return this;
    }

    /**
     * @return the sampling probability
     */
    public double probability() {
        return probability;
    }

    /**
     * @return the seed
     */
    public int seed() {
        return seed;
    }

    /**
     * @return the hash
     */
    public int hash() {
        return hash;
    }

    /**
     * Writes probability, seed and hash, in that order.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeDouble(probability);
        out.writeInt(seed);
        out.writeInt(hash);
    }

    /**
     * Renders the query as an object named {@link #NAME} holding the probability, seed and hash fields.
     */
    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(PROBABILITY.getPreferredName(), probability);
        builder.field(SEED.getPreferredName(), seed);
        builder.field(HASH.getPreferredName(), hash);
        builder.endObject();
    }

    // args[0] is the required probability; args[1] (seed) and args[2] (hash) are optional and
    // are null when absent, in which case the builder keeps its defaults.
    private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
            if (args[1] != null) {
                randomSamplingQueryBuilder.seed((int) args[1]);
            }
            if (args[2] != null) {
                randomSamplingQueryBuilder.hash((int) args[2]);
            }
            return randomSamplingQueryBuilder;
        }
    );

    static {
        PARSER.declareDouble(constructorArg(), PROBABILITY);
        PARSER.declareInt(optionalConstructorArg(), SEED);
        PARSER.declareInt(optionalConstructorArg(), HASH);
    }

    /**
     * Parses a builder from XContent using the declared probability, seed and hash fields.
     */
    public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates the {@link RandomSamplingQuery} for this builder's probability, seed and hash.
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        return new RandomSamplingQuery(probability, seed, hash);
    }

    @Override
    protected boolean doEquals(RandomSamplingQueryBuilder other) {
        // Double.compare keeps equality consistent with doHashCode (which hashes the boxed
        // double) and reflexive even for NaN, unlike the == operator.
        return Double.compare(probability, other.probability) == 0 && seed == other.seed && hash == other.hash;
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(probability, seed, hash);
    }

    /**
     * Returns the name of the writeable object
     */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * The minimal version of the recipient this object can be sent to
     */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER_8_19;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ public CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> getReque
"range",
"regexp",
"knn_score_doc",
"random_sampling",
"script",
"script_score",
"simple_query_string",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.Query;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.test.AbstractQueryTestCase;
import org.elasticsearch.xcontent.XContentParseException;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

/**
 * Tests for {@link RandomSamplingQueryBuilder}: random builder creation, conversion to
 * {@link RandomSamplingQuery}, and rejection of unknown fields while parsing.
 */
public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {

    /**
     * Creates a builder with a random probability; the seed and the hash are each set to a
     * random value only when a coin flip says so.
     */
    @Override
    protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
        final RandomSamplingQueryBuilder queryBuilder = new RandomSamplingQueryBuilder(randomDoubleBetween(0.0, 1.0, false));
        if (randomBoolean()) {
            queryBuilder.seed(randomInt());
        }
        if (randomBoolean()) {
            queryBuilder.hash(randomInt());
        }
        return queryBuilder;
    }

    /**
     * Checks that the produced query is a {@link RandomSamplingQuery} carrying the builder's
     * probability, seed and hash.
     */
    @Override
    protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
        throws IOException {
        var samplingQuery = asInstanceOf(RandomSamplingQuery.class, query);
        assertThat(samplingQuery.probability(), equalTo(queryBuilder.probability()));
        assertThat(samplingQuery.seed(), equalTo(queryBuilder.seed()));
        assertThat(samplingQuery.hash(), equalTo(queryBuilder.hash()));
    }

    @Override
    protected boolean supportsBoost() {
        return false;
    }

    @Override
    protected boolean supportsQueryName() {
        return false;
    }

    /**
     * Parses a query body that starts with an unknown field and expects the parse failure to name it.
     */
    @Override
    public void testUnknownField() {
        final String probabilityKey = RandomSamplingQueryBuilder.PROBABILITY.getPreferredName();
        final String seedKey = RandomSamplingQueryBuilder.SEED.getPreferredName();
        final String hashKey = RandomSamplingQueryBuilder.HASH.getPreferredName();

        // The random values are drawn left to right: boolean, int, int.
        final String json = "{ \""
            + RandomSamplingQueryBuilder.NAME
            + "\" : {"
            + quotedEntry("bogusField", "someValue")
            + ", "
            + quotedEntry(probabilityKey, randomBoolean())
            + ", "
            + quotedEntry(seedKey, randomInt())
            + ", "
            + quotedEntry(hashKey, randomInt())
            + " } }";

        final XContentParseException failure = expectThrows(XContentParseException.class, () -> parseQuery(json));
        assertTrue(failure.getMessage().contains("bogusField"));
    }

    /**
     * Renders {@code "key" : "value"}, with both the key and the value wrapped in double quotes.
     */
    private static String quotedEntry(String key, Object value) {
        return "\"" + key + "\" : \"" + value + "\"";
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ public static boolean isSupported(String name) {
return ATTRIBUTES_MAP.containsKey(name);
}

/**
 * Tells whether the given expression is a {@link MetadataAttribute} whose name equals {@code SCORE}.
 *
 * @param a the expression to inspect
 * @return {@code true} only for a metadata attribute named {@code SCORE}
 */
public static boolean isScoreAttribute(Expression a) {
    if (a instanceof MetadataAttribute metadataAttribute) {
        return metadataAttribute.name().equals(SCORE);
    }
    return false;
}

@Override
@SuppressWarnings("checkstyle:EqualsHashCode")// equals is implemented in parent. See innerEquals instead
public int hashCode() {
Expand Down
Loading