Skip to content

Commit ead4eb5

Browse files
Hohol authored and polyfractal committed
Add more flexibility to MovingFunction window alignment (#44360)
Introduce a shift field to the MovingFunction aggregation. By default, shift = 0, and the behavior is the same as before. Increasing shift by 1 moves the starting window position by 1 to the right. To simply include the current bucket in the window, use shift = 1. For center alignment (n/2 values before and after the current bucket), use shift = window / 2. For right alignment (n values after the current bucket), use shift = window.
1 parent 4dbba53 commit ead4eb5

File tree

5 files changed

+112
-25
lines changed

5 files changed

+112
-25
lines changed

docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@ A `moving_fn` aggregation looks like this in isolation:
2424
--------------------------------------------------
2525
// NOTCONSOLE
2626

27-
[[moving-avg-params]]
28-
.`moving_avg` Parameters
27+
[[moving-fn-params]]
28+
.`moving_fn` Parameters
2929
[options="header"]
3030
|===
3131
|Parameter Name |Description |Required |Default Value
3232
|`buckets_path` |Path to the metric of interest (see <<buckets-path-syntax, `buckets_path` Syntax>> for more details) |Required |
3333
|`window` |The size of the window to "slide" across the histogram. |Required |
3434
|`script` |The script that should be executed on each window of data |Required |
35+
|`shift` |<<shift-parameter, Shift>> of window position. |Optional | 0
3536
|===
3637

3738
`moving_fn` aggregations must be embedded inside of a `histogram` or `date_histogram` aggregation. They can be
@@ -169,6 +170,18 @@ POST /_search
169170
// CONSOLE
170171
// TEST[setup:sales]
171172

173+
[[shift-parameter]]
174+
==== shift parameter
175+
176+
By default (with `shift = 0`), the window offered for calculation consists of the last `n` values, excluding the current bucket, where `n` is the `window` size.
177+
Increasing `shift` by 1 moves the starting position of the window by `1` to the right.
178+
179+
- To include the current bucket in the window, use `shift = 1`.
180+
- For center alignment (`n / 2` values before and after the current bucket), use `shift = window / 2`.
181+
- For right alignment (`n` values after the current bucket), use `shift = window`.
182+
183+
If either window edge moves outside the borders of the data series, the window shrinks to include only the available values.
184+
172185
==== Pre-built Functions
173186

174187
For convenience, a number of functions have been prebuilt and are available inside the `moving_fn` script context:

server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.search.aggregations.pipeline;
2121

22+
import org.elasticsearch.Version;
2223
import org.elasticsearch.common.ParseField;
2324
import org.elasticsearch.common.Strings;
2425
import org.elasticsearch.common.io.stream.StreamInput;
@@ -48,12 +49,14 @@
4849
public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<MovFnPipelineAggregationBuilder> {
4950
public static final String NAME = "moving_fn";
5051
private static final ParseField WINDOW = new ParseField("window");
52+
private static final ParseField SHIFT = new ParseField("shift");
5153

5254
private final Script script;
5355
private final String bucketsPathString;
5456
private String format = null;
5557
private GapPolicy gapPolicy = GapPolicy.SKIP;
5658
private int window;
59+
private int shift;
5760

5861
private static final Function<String, ConstructingObjectParser<MovFnPipelineAggregationBuilder, Void>> PARSER
5962
= name -> {
@@ -68,6 +71,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
6871
(p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
6972
parser.declareInt(ConstructingObjectParser.constructorArg(), WINDOW);
7073

74+
parser.declareInt(MovFnPipelineAggregationBuilder::setShift, SHIFT);
7175
parser.declareString(MovFnPipelineAggregationBuilder::format, FORMAT);
7276
parser.declareField(MovFnPipelineAggregationBuilder::gapPolicy, p -> {
7377
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
@@ -97,6 +101,11 @@ public MovFnPipelineAggregationBuilder(StreamInput in) throws IOException {
97101
format = in.readOptionalString();
98102
gapPolicy = GapPolicy.readFrom(in);
99103
window = in.readInt();
104+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
105+
shift = in.readInt();
106+
} else {
107+
shift = 0;
108+
}
100109
}
101110

102111
@Override
@@ -106,6 +115,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
106115
out.writeOptionalString(format);
107116
gapPolicy.writeTo(out);
108117
out.writeInt(window);
118+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
119+
out.writeInt(shift);
120+
}
109121
}
110122

111123
/**
@@ -168,9 +180,13 @@ public void setWindow(int window) {
168180
this.window = window;
169181
}
170182

183+
public void setShift(int shift) {
184+
this.shift = shift;
185+
}
186+
171187
@Override
172188
public void doValidate(AggregatorFactory parent, Collection<AggregationBuilder> aggFactories,
173-
Collection<PipelineAggregationBuilder> pipelineAggregatoractories) {
189+
Collection<PipelineAggregationBuilder> pipelineAggregatorFactories) {
174190
if (window <= 0) {
175191
throw new IllegalArgumentException("[" + WINDOW.getPreferredName() + "] must be a positive, non-zero integer.");
176192
}
@@ -180,7 +196,7 @@ public void doValidate(AggregatorFactory parent, Collection<AggregationBuilder>
180196

181197
@Override
182198
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
183-
return new MovFnPipelineAggregator(name, bucketsPathString, script, window, formatter(), gapPolicy, metaData);
199+
return new MovFnPipelineAggregator(name, bucketsPathString, script, window, shift, formatter(), gapPolicy, metaData);
184200
}
185201

186202
@Override
@@ -192,6 +208,7 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param
192208
}
193209
builder.field(GAP_POLICY.getPreferredName(), gapPolicy.getName());
194210
builder.field(WINDOW.getPreferredName(), window);
211+
builder.field(SHIFT.getPreferredName(), shift);
195212
return builder;
196213
}
197214

@@ -225,7 +242,7 @@ protected boolean overrideBucketsPath() {
225242

226243
@Override
227244
public int hashCode() {
228-
return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window);
245+
return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window, shift);
229246
}
230247

231248
@Override
@@ -238,7 +255,8 @@ public boolean equals(Object obj) {
238255
&& Objects.equals(script, other.script)
239256
&& Objects.equals(format, other.format)
240257
&& Objects.equals(gapPolicy, other.gapPolicy)
241-
&& Objects.equals(window, other.window);
258+
&& Objects.equals(window, other.window)
259+
&& Objects.equals(shift, other.shift);
242260
}
243261

244262
@Override

server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
package org.elasticsearch.search.aggregations.pipeline;
2121

22-
import org.elasticsearch.common.collect.EvictingQueue;
22+
import org.elasticsearch.Version;
2323
import org.elasticsearch.common.io.stream.StreamInput;
2424
import org.elasticsearch.common.io.stream.StreamOutput;
2525
import org.elasticsearch.script.Script;
@@ -63,15 +63,17 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
6363
private final Script script;
6464
private final String bucketsPath;
6565
private final int window;
66+
private final int shift;
6667

67-
MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, DocValueFormat formatter,
68+
MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, int shift, DocValueFormat formatter,
6869
BucketHelpers.GapPolicy gapPolicy, Map<String, Object> metadata) {
6970
super(name, new String[]{bucketsPath}, metadata);
7071
this.bucketsPath = bucketsPath;
7172
this.script = script;
7273
this.formatter = formatter;
7374
this.gapPolicy = gapPolicy;
7475
this.window = window;
76+
this.shift = shift;
7577
}
7678

7779
public MovFnPipelineAggregator(StreamInput in) throws IOException {
@@ -81,6 +83,11 @@ public MovFnPipelineAggregator(StreamInput in) throws IOException {
8183
gapPolicy = BucketHelpers.GapPolicy.readFrom(in);
8284
bucketsPath = in.readString();
8385
window = in.readInt();
86+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
87+
shift = in.readInt();
88+
} else {
89+
shift = 0;
90+
}
8491
}
8592

8693
@Override
@@ -90,6 +97,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
9097
gapPolicy.writeTo(out);
9198
out.writeString(bucketsPath);
9299
out.writeInt(window);
100+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
101+
out.writeInt(shift);
102+
}
93103
}
94104

95105
@Override
@@ -106,7 +116,6 @@ public InternalAggregation reduce(InternalAggregation aggregation, InternalAggre
106116
HistogramFactory factory = (HistogramFactory) histo;
107117

108118
List<MultiBucketsAggregation.Bucket> newBuckets = new ArrayList<>();
109-
EvictingQueue<Double> values = new EvictingQueue<>(this.window);
110119

111120
// Initialize the script
112121
MovingFunctionScript.Factory scriptFactory = reduceContext.scriptService().compile(script, MovingFunctionScript.CONTEXT);
@@ -117,30 +126,53 @@ public InternalAggregation reduce(InternalAggregation aggregation, InternalAggre
117126

118127
MovingFunctionScript executableScript = scriptFactory.newInstance();
119128

129+
List<Double> values = buckets.stream()
130+
.map(b -> resolveBucketValue(histo, b, bucketsPaths()[0], gapPolicy))
131+
.filter(v -> v != null && v.isNaN() == false)
132+
.collect(Collectors.toList());
133+
134+
int index = 0;
120135
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
121136
Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
122137

123138
// Default is to reuse existing bucket. Simplifies the rest of the logic,
124139
// since we only change newBucket if we can add to it
125140
MultiBucketsAggregation.Bucket newBucket = bucket;
126141

127-
if (thisBucketValue != null && thisBucketValue.equals(Double.NaN) == false) {
142+
if (thisBucketValue != null && thisBucketValue.isNaN() == false) {
128143

129144
// The custom context mandates that the script returns a double (not Double) so we
130145
// don't need null checks, etc.
131-
double movavg = executableScript.execute(vars, values.stream().mapToDouble(Double::doubleValue).toArray());
146+
int fromIndex = clamp(index - window + shift, values);
147+
int toIndex = clamp(index + shift, values);
148+
double movavg = executableScript.execute(
149+
vars,
150+
values.subList(fromIndex, toIndex).stream()
151+
.mapToDouble(Double::doubleValue)
152+
.toArray()
153+
);
132154

133155
List<InternalAggregation> aggs = StreamSupport
134156
.stream(bucket.getAggregations().spliterator(), false)
135157
.map(InternalAggregation.class::cast)
136158
.collect(Collectors.toList());
137159
aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<>(), metaData()));
138160
newBucket = factory.createBucket(factory.getKey(bucket), bucket.getDocCount(), new InternalAggregations(aggs));
139-
values.offer(thisBucketValue);
161+
index++;
140162
}
141163
newBuckets.add(newBucket);
142164
}
143165

144166
return factory.createAggregation(newBuckets);
145167
}
168+
169+
private int clamp(int index, List<Double> list) {
170+
if (index < 0) {
171+
return 0;
172+
}
173+
if (index > list.size()) {
174+
return list.size();
175+
}
176+
return index;
177+
}
146178
}

server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.common.io.stream.Writeable;
2323
import org.elasticsearch.common.xcontent.XContentParser;
2424
import org.elasticsearch.script.Script;
25-
import org.elasticsearch.search.aggregations.pipeline.MovFnPipelineAggregationBuilder;
2625
import org.elasticsearch.test.AbstractSerializingTestCase;
2726

2827
import java.io.IOException;
@@ -31,7 +30,14 @@ public class MovFnPipelineAggregationBuilderSerializationTests extends AbstractS
3130

3231
@Override
3332
protected MovFnPipelineAggregationBuilder createTestInstance() {
34-
return new MovFnPipelineAggregationBuilder(randomAlphaOfLength(10), "foo", new Script("foo"), randomIntBetween(1, 10));
33+
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder(
34+
randomAlphaOfLength(10),
35+
"foo",
36+
new Script("foo"),
37+
randomIntBetween(1, 10)
38+
);
39+
builder.setShift(randomIntBetween(1, 10));
40+
return builder;
3541
}
3642

3743
@Override

server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import java.util.Map;
5454
import java.util.Set;
5555
import java.util.function.Consumer;
56+
import java.util.stream.Collectors;
5657

5758
import static org.hamcrest.Matchers.equalTo;
5859
import static org.mockito.Mockito.mock;
@@ -79,25 +80,42 @@ public class MovFnUnitTests extends AggregatorTestCase {
7980
private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);
8081

8182
public void testMatchAllDocs() throws IOException {
82-
Query query = new MatchAllDocsQuery();
83+
check(0, List.of(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
84+
}
85+
86+
public void testShift() throws IOException {
87+
check(1, List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
88+
check(5, List.of(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
89+
check(-5, List.of(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
90+
}
91+
92+
public void testWideWindow() throws IOException {
8393
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
94+
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
95+
builder.setShift(50);
96+
check(builder, script, List.of(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
97+
}
8498

99+
private void check(int shift, List<Double> expected) throws IOException {
100+
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
101+
MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
102+
builder.setShift(shift);
103+
check(builder, script, expected);
104+
}
105+
106+
private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
107+
Query query = new MatchAllDocsQuery();
85108
DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
86109
aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);
87110
aggBuilder.subAggregation(new AvgAggregationBuilder("avg").field(VALUE_FIELD));
88-
aggBuilder.subAggregation(new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3));
111+
aggBuilder.subAggregation(builder);
89112

90113
executeTestCase(query, aggBuilder, histogram -> {
91-
assertEquals(10, histogram.getBuckets().size());
92114
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
93-
for (int i = 0; i < buckets.size(); i++) {
94-
if (i == 0) {
95-
assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(Double.NaN));
96-
} else {
97-
assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(((double) i)));
98-
}
99-
100-
}
115+
List<Double> actual = buckets.stream()
116+
.map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
117+
.collect(Collectors.toList());
118+
assertThat(actual, equalTo(expected));
101119
}, 1000, script);
102120
}
103121

0 commit comments

Comments
 (0)