Skip to content

[7.x backport] Add more flexibility to MovingFunction window alignment #45159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ A `moving_fn` aggregation looks like this in isolation:
--------------------------------------------------
// NOTCONSOLE

[[moving-avg-params]]
.`moving_avg` Parameters
[[moving-fn-params]]
.`moving_fn` Parameters
[options="header"]
|===
|Parameter Name |Description |Required |Default Value
|`buckets_path` |Path to the metric of interest (see <<buckets-path-syntax, `buckets_path` Syntax>> for more details |Required |
|`window` |The size of window to "slide" across the histogram. |Required |
|`script` |The script that should be executed on each window of data |Required |
|`shift` |<<shift-parameter, Shift>> of window position. |Optional | 0
|===

`moving_fn` aggregations must be embedded inside of a `histogram` or `date_histogram` aggregation. They can be
Expand Down Expand Up @@ -169,6 +170,18 @@ POST /_search
// CONSOLE
// TEST[setup:sales]

[[shift-parameter]]
==== shift parameter

By default (with `shift = 0`), the window that is offered for calculation is the last `n` values excluding the current bucket.
Increasing `shift` by 1 moves starting window position by `1` to the right.

- To include current bucket to the window, use `shift = 1`.
- For center alignment (`n / 2` values before and after the current bucket), use `shift = window / 2`.
- For right alignment (`n` values after the current bucket), use `shift = window`.

If either of window edges moves outside the borders of data series, the window shrinks to include available values only.

==== Pre-built Functions

For convenience, a number of functions have been prebuilt and are available inside the `moving_fn` script context:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search.aggregations.pipeline;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -48,12 +49,14 @@
public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<MovFnPipelineAggregationBuilder> {
public static final String NAME = "moving_fn";
private static final ParseField WINDOW = new ParseField("window");
private static final ParseField SHIFT = new ParseField("shift");

private final Script script;
private final String bucketsPathString;
private String format = null;
private GapPolicy gapPolicy = GapPolicy.SKIP;
private int window;
private int shift;

private static final Function<String, ConstructingObjectParser<MovFnPipelineAggregationBuilder, Void>> PARSER
= name -> {
Expand All @@ -68,6 +71,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
(p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
parser.declareInt(ConstructingObjectParser.constructorArg(), WINDOW);

parser.declareInt(MovFnPipelineAggregationBuilder::setShift, SHIFT);
parser.declareString(MovFnPipelineAggregationBuilder::format, FORMAT);
parser.declareField(MovFnPipelineAggregationBuilder::gapPolicy, p -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
Expand Down Expand Up @@ -97,6 +101,11 @@ public MovFnPipelineAggregationBuilder(StreamInput in) throws IOException {
format = in.readOptionalString();
gapPolicy = GapPolicy.readFrom(in);
window = in.readInt();
if (in.getVersion().onOrAfter(Version.V_7_4_0)) {
shift = in.readInt();
} else {
shift = 0;
}
}

@Override
Expand All @@ -106,6 +115,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalString(format);
gapPolicy.writeTo(out);
out.writeInt(window);
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeInt(shift);
}
}

/**
Expand Down Expand Up @@ -168,9 +180,13 @@ public void setWindow(int window) {
this.window = window;
}

public void setShift(int shift) {
this.shift = shift;
}

@Override
public void doValidate(AggregatorFactory parent, Collection<AggregationBuilder> aggFactories,
Collection<PipelineAggregationBuilder> pipelineAggregatoractories) {
Collection<PipelineAggregationBuilder> pipelineAggregatorFactories) {
if (window <= 0) {
throw new IllegalArgumentException("[" + WINDOW.getPreferredName() + "] must be a positive, non-zero integer.");
}
Expand All @@ -180,7 +196,7 @@ public void doValidate(AggregatorFactory parent, Collection<AggregationBuilder>

@Override
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
return new MovFnPipelineAggregator(name, bucketsPathString, script, window, formatter(), gapPolicy, metaData);
return new MovFnPipelineAggregator(name, bucketsPathString, script, window, shift, formatter(), gapPolicy, metaData);
}

@Override
Expand All @@ -192,6 +208,7 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param
}
builder.field(GAP_POLICY.getPreferredName(), gapPolicy.getName());
builder.field(WINDOW.getPreferredName(), window);
builder.field(SHIFT.getPreferredName(), shift);
return builder;
}

Expand Down Expand Up @@ -225,7 +242,7 @@ protected boolean overrideBucketsPath() {

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window);
return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window, shift);
}

@Override
Expand All @@ -238,7 +255,8 @@ public boolean equals(Object obj) {
&& Objects.equals(script, other.script)
&& Objects.equals(format, other.format)
&& Objects.equals(gapPolicy, other.gapPolicy)
&& Objects.equals(window, other.window);
&& Objects.equals(window, other.window)
&& Objects.equals(shift, other.shift);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.elasticsearch.search.aggregations.pipeline;

import org.elasticsearch.common.collect.EvictingQueue;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.script.Script;
Expand Down Expand Up @@ -63,15 +63,17 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
private final Script script;
private final String bucketsPath;
private final int window;
private final int shift;

MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, DocValueFormat formatter,
MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, int shift, DocValueFormat formatter,
BucketHelpers.GapPolicy gapPolicy, Map<String, Object> metadata) {
super(name, new String[]{bucketsPath}, metadata);
this.bucketsPath = bucketsPath;
this.script = script;
this.formatter = formatter;
this.gapPolicy = gapPolicy;
this.window = window;
this.shift = shift;
}

public MovFnPipelineAggregator(StreamInput in) throws IOException {
Expand All @@ -81,6 +83,11 @@ public MovFnPipelineAggregator(StreamInput in) throws IOException {
gapPolicy = BucketHelpers.GapPolicy.readFrom(in);
bucketsPath = in.readString();
window = in.readInt();
if (in.getVersion().onOrAfter(Version.V_7_4_0)) {
shift = in.readInt();
} else {
shift = 0;
}
}

@Override
Expand All @@ -90,6 +97,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
gapPolicy.writeTo(out);
out.writeString(bucketsPath);
out.writeInt(window);
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeInt(shift);
}
}

@Override
Expand All @@ -106,7 +116,6 @@ public InternalAggregation reduce(InternalAggregation aggregation, InternalAggre
HistogramFactory factory = (HistogramFactory) histo;

List<MultiBucketsAggregation.Bucket> newBuckets = new ArrayList<>();
EvictingQueue<Double> values = new EvictingQueue<>(this.window);

// Initialize the script
MovingFunctionScript.Factory scriptFactory = reduceContext.scriptService().compile(script, MovingFunctionScript.CONTEXT);
Expand All @@ -117,30 +126,53 @@ public InternalAggregation reduce(InternalAggregation aggregation, InternalAggre

MovingFunctionScript executableScript = scriptFactory.newInstance();

List<Double> values = buckets.stream()
.map(b -> resolveBucketValue(histo, b, bucketsPaths()[0], gapPolicy))
.filter(v -> v != null && v.isNaN() == false)
.collect(Collectors.toList());

int index = 0;
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);

// Default is to reuse existing bucket. Simplifies the rest of the logic,
// since we only change newBucket if we can add to it
MultiBucketsAggregation.Bucket newBucket = bucket;

if (thisBucketValue != null && thisBucketValue.equals(Double.NaN) == false) {
if (thisBucketValue != null && thisBucketValue.isNaN() == false) {

// The custom context mandates that the script returns a double (not Double) so we
// don't need null checks, etc.
double movavg = executableScript.execute(vars, values.stream().mapToDouble(Double::doubleValue).toArray());
int fromIndex = clamp(index - window + shift, values);
int toIndex = clamp(index + shift, values);
double movavg = executableScript.execute(
vars,
values.subList(fromIndex, toIndex).stream()
.mapToDouble(Double::doubleValue)
.toArray()
);

List<InternalAggregation> aggs = StreamSupport
.stream(bucket.getAggregations().spliterator(), false)
.map(InternalAggregation.class::cast)
.collect(Collectors.toList());
aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<>(), metaData()));
newBucket = factory.createBucket(factory.getKey(bucket), bucket.getDocCount(), new InternalAggregations(aggs));
values.offer(thisBucketValue);
index++;
}
newBuckets.add(newBucket);
}

return factory.createAggregation(newBuckets);
}

private int clamp(int index, List<Double> list) {
if (index < 0) {
return 0;
}
if (index > list.size()) {
return list.size();
}
return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.pipeline.MovFnPipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
Expand All @@ -31,7 +30,14 @@ public class MovFnPipelineAggregationBuilderSerializationTests extends AbstractS

@Override
protected MovFnPipelineAggregationBuilder createTestInstance() {
return new MovFnPipelineAggregationBuilder(randomAlphaOfLength(10), "foo", new Script("foo"), randomIntBetween(1, 10));
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder(
randomAlphaOfLength(10),
"foo",
new Script("foo"),
randomIntBetween(1, 10)
);
builder.setShift(randomIntBetween(1, 10));
return builder;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
Expand All @@ -79,25 +80,42 @@ public class MovFnUnitTests extends AggregatorTestCase {
private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);

public void testMatchAllDocs() throws IOException {
Query query = new MatchAllDocsQuery();
check(0, Arrays.asList(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
}

public void testShift() throws IOException {
check(1, Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
check(5, Arrays.asList(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
check(-5, Arrays.asList(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
}

public void testWideWindow() throws IOException {
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
builder.setShift(50);
check(builder, script, Arrays.asList(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
}

private void check(int shift, List<Double> expected) throws IOException {
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
builder.setShift(shift);
check(builder, script, expected);
}

private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
Query query = new MatchAllDocsQuery();
DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);
aggBuilder.subAggregation(new AvgAggregationBuilder("avg").field(VALUE_FIELD));
aggBuilder.subAggregation(new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3));
aggBuilder.subAggregation(builder);

executeTestCase(query, aggBuilder, histogram -> {
assertEquals(10, histogram.getBuckets().size());
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
for (int i = 0; i < buckets.size(); i++) {
if (i == 0) {
assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(Double.NaN));
} else {
assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(((double) i)));
}

}
List<Double> actual = buckets.stream()
.map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
.collect(Collectors.toList());
assertThat(actual, equalTo(expected));
}, 1000, script);
}

Expand Down