Skip to content

Commit 5758a11

Browse files
author
Hendrik Muhs
authored
[ML-DataFrame] remove array arguments for group_by (#38895)
remove array arguments for group_by
1 parent 2dabf4a commit 5758a11

File tree

20 files changed

+363
-179
lines changed

20 files changed

+363
-179
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
public final class DataFrameField {
1515

1616
// common parse fields
17+
public static final ParseField AGGREGATIONS = new ParseField("aggregations");
18+
public static final ParseField AGGS = new ParseField("aggs");
1719
public static final ParseField ID = new ParseField("id");
1820
public static final ParseField TRANSFORMS = new ParseField("transforms");
1921
public static final ParseField COUNT = new ParseField("count");
22+
public static final ParseField GROUP_BY = new ParseField("group_by");
2023
public static final ParseField TIMEOUT = new ParseField("timeout");
2124
public static final ParseField WAIT_FOR_COMPLETION = new ParseField("wait_for_completion");
2225
public static final ParseField STATS_FIELD = new ParseField("stats");

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,19 @@ public class DataFrameMessages {
3333
"Failed to parse transform configuration for data frame transform [{0}]";
3434
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_NO_TRANSFORM =
3535
"Data frame transform configuration must specify exactly 1 function";
36+
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY =
37+
"Data frame pivot transform configuration must specify at least 1 group_by";
38+
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION =
39+
"Data frame pivot transform configuration must specify at least 1 aggregation";
3640
public static final String DATA_FRAME_TRANSFORM_PIVOT_FAILED_TO_CREATE_COMPOSITE_AGGREGATION =
3741
"Failed to create composite aggregation from pivot function";
3842
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_INVALID =
3943
"Data frame transform configuration [{0}] has invalid elements";
4044

4145
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_QUERY =
4246
"Failed to parse query for data frame transform";
47+
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_GROUP_BY =
48+
"Failed to parse group_by for data frame pivot transform";
4349
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION =
4450
"Failed to parse aggregation for data frame pivot transform";
4551

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFramePivotRestIT.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ public void testHistogramPivot() throws Exception {
8787

8888

8989
config += " \"pivot\": {"
90-
+ " \"group_by\": [ {"
90+
+ " \"group_by\": {"
9191
+ " \"every_2\": {"
9292
+ " \"histogram\": {"
9393
+ " \"interval\": 2,\"field\":\"stars\""
94-
+ " } } } ],"
94+
+ " } } },"
9595
+ " \"aggregations\": {"
9696
+ " \"avg_rating\": {"
9797
+ " \"avg\": {"
@@ -125,11 +125,11 @@ public void testBiggerPivot() throws Exception {
125125

126126

127127
config += " \"pivot\": {"
128-
+ " \"group_by\": [ {"
128+
+ " \"group_by\": {"
129129
+ " \"reviewer\": {"
130130
+ " \"terms\": {"
131131
+ " \"field\": \"user_id\""
132-
+ " } } } ],"
132+
+ " } } },"
133133
+ " \"aggregations\": {"
134134
+ " \"avg_rating\": {"
135135
+ " \"avg\": {"
@@ -199,11 +199,11 @@ public void testDateHistogramPivot() throws Exception {
199199

200200

201201
config += " \"pivot\": {"
202-
+ " \"group_by\": [ {"
202+
+ " \"group_by\": {"
203203
+ " \"by_day\": {"
204204
+ " \"date_histogram\": {"
205205
+ " \"interval\": \"1d\",\"field\":\"timestamp\",\"format\":\"yyyy-MM-DD\""
206-
+ " } } } ],"
206+
+ " } } },"
207207
+ " \"aggregations\": {"
208208
+ " \"avg_rating\": {"
209209
+ " \"avg\": {"

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameRestTestCase.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ protected void createPivotReviewsTransform(String transformId, String dataFrameI
125125
}
126126

127127
config += " \"pivot\": {"
128-
+ " \"group_by\": [ {"
128+
+ " \"group_by\": {"
129129
+ " \"reviewer\": {"
130130
+ " \"terms\": {"
131131
+ " \"field\": \"user_id\""
132-
+ " } } } ],"
132+
+ " } } },"
133133
+ " \"aggregations\": {"
134134
+ " \"avg_rating\": {"
135135
+ " \"avg\": {"

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationConfig.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,27 @@ public static AggregationConfig fromXContent(final XContentParser parser, boolea
7070
NamedXContentRegistry registry = parser.getXContentRegistry();
7171
Map<String, Object> source = parser.mapOrdered();
7272
AggregatorFactories.Builder aggregations = null;
73-
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
74-
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
75-
BytesReference.bytes(xContentBuilder).streamInput())) {
76-
sourceParser.nextToken();
77-
aggregations = AggregatorFactories.parseAggregators(sourceParser);
78-
} catch (Exception e) {
73+
74+
if (source.isEmpty()) {
7975
if (lenient) {
80-
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION, e);
76+
logger.warn(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION);
8177
} else {
82-
throw e;
78+
throw new IllegalArgumentException(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION);
79+
}
80+
} else {
81+
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
82+
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
83+
BytesReference.bytes(xContentBuilder).streamInput())) {
84+
sourceParser.nextToken();
85+
aggregations = AggregatorFactories.parseAggregators(sourceParser);
86+
} catch (Exception e) {
87+
if (lenient) {
88+
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION, e);
89+
} else {
90+
throw e;
91+
}
8392
}
8493
}
85-
8694
return new AggregationConfig(source, aggregations);
8795
}
8896

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/AggregationResultUtils.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,22 @@ final class AggregationResultUtils {
2727
* Extracts aggregation results from a composite aggregation and puts it into a map.
2828
*
2929
* @param agg The aggregation result
30-
* @param sources The original sources used for querying
30+
* @param groups The original groupings used for querying
3131
* @param aggregationBuilders the aggregation used for querying
3232
* @param dataFrameIndexerTransformStats stats collector
3333
* @return a map containing the results of the aggregation in a consumable way
3434
*/
3535
public static Stream<Map<String, Object>> extractCompositeAggregationResults(CompositeAggregation agg,
36-
Iterable<GroupConfig> sources, Collection<AggregationBuilder> aggregationBuilders,
36+
GroupConfig groups, Collection<AggregationBuilder> aggregationBuilders,
3737
DataFrameIndexerTransformStats dataFrameIndexerTransformStats) {
3838
return agg.getBuckets().stream().map(bucket -> {
3939
dataFrameIndexerTransformStats.incrementNumDocuments(bucket.getDocCount());
4040

4141
Map<String, Object> document = new HashMap<>();
42-
for (GroupConfig source : sources) {
43-
String destinationFieldName = source.getDestinationFieldName();
42+
groups.getGroups().keySet().forEach(destinationFieldName -> {
4443
document.put(destinationFieldName, bucket.getKey().get(destinationFieldName));
45-
}
44+
});
45+
4646
for (AggregationBuilder aggregationBuilder : aggregationBuilders) {
4747
String aggName = aggregationBuilder.getName();
4848

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/DateHistogramGroupSource.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ public static DateHistogramGroupSource fromXContent(final XContentParser parser,
8383
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
8484
}
8585

86+
@Override
87+
public Type getType() {
88+
return Type.DATE_HISTOGRAM;
89+
}
90+
8691
public long getInterval() {
8792
return interval;
8893
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/GroupConfig.java

Lines changed: 98 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,29 @@
66

77
package org.elasticsearch.xpack.dataframe.transforms.pivot;
88

9+
import org.apache.logging.log4j.LogManager;
10+
import org.apache.logging.log4j.Logger;
911
import org.elasticsearch.common.ParsingException;
12+
import org.elasticsearch.common.bytes.BytesReference;
1013
import org.elasticsearch.common.io.stream.StreamInput;
1114
import org.elasticsearch.common.io.stream.StreamOutput;
1215
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
17+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
1318
import org.elasticsearch.common.xcontent.ToXContentObject;
1419
import org.elasticsearch.common.xcontent.XContentBuilder;
20+
import org.elasticsearch.common.xcontent.XContentFactory;
1521
import org.elasticsearch.common.xcontent.XContentParser;
22+
import org.elasticsearch.common.xcontent.XContentType;
23+
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
24+
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;
25+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1626
import org.elasticsearch.xpack.dataframe.transforms.pivot.SingleGroupSource.Type;
1727

1828
import java.io.IOException;
29+
import java.util.LinkedHashMap;
1930
import java.util.Locale;
31+
import java.util.Map;
2032
import java.util.Objects;
2133

2234
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -26,58 +38,53 @@
2638
*/
2739
public class GroupConfig implements Writeable, ToXContentObject {
2840

29-
private final String destinationFieldName;
30-
private final SingleGroupSource.Type groupType;
31-
private final SingleGroupSource<?> groupSource;
41+
private static final Logger logger = LogManager.getLogger(GroupConfig.class);
3242

33-
public GroupConfig(final String destinationFieldName, final SingleGroupSource.Type groupType, final SingleGroupSource<?> groupSource) {
34-
this.destinationFieldName = Objects.requireNonNull(destinationFieldName);
35-
this.groupType = Objects.requireNonNull(groupType);
36-
this.groupSource = Objects.requireNonNull(groupSource);
43+
private final Map<String, Object> source;
44+
private final Map<String, SingleGroupSource<?>> groups;
45+
46+
public GroupConfig(final Map<String, Object> source, final Map<String, SingleGroupSource<?>> groups) {
47+
this.source = ExceptionsHelper.requireNonNull(source, DataFrameField.GROUP_BY.getPreferredName());
48+
this.groups = groups;
3749
}
3850

3951
public GroupConfig(StreamInput in) throws IOException {
40-
destinationFieldName = in.readString();
41-
groupType = Type.fromId(in.readByte());
42-
switch (groupType) {
43-
case TERMS:
44-
groupSource = in.readOptionalWriteable(TermsGroupSource::new);
45-
break;
46-
case HISTOGRAM:
47-
groupSource = in.readOptionalWriteable(HistogramGroupSource::new);
48-
break;
49-
case DATE_HISTOGRAM:
50-
groupSource = in.readOptionalWriteable(DateHistogramGroupSource::new);
51-
break;
52-
default:
53-
throw new IOException("Unknown group type");
54-
}
52+
source = in.readMap();
53+
groups = in.readMap(StreamInput::readString, (stream) -> {
54+
Type groupType = Type.fromId(stream.readByte());
55+
switch (groupType) {
56+
case TERMS:
57+
return new TermsGroupSource(stream);
58+
case HISTOGRAM:
59+
return new HistogramGroupSource(stream);
60+
case DATE_HISTOGRAM:
61+
return new DateHistogramGroupSource(stream);
62+
default:
63+
throw new IOException("Unknown group type");
64+
}
65+
});
5566
}
5667

57-
@Override
58-
public void writeTo(StreamOutput out) throws IOException {
59-
out.writeString(destinationFieldName);
60-
out.writeByte(groupType.getId());
61-
out.writeOptionalWriteable(groupSource);
68+
public Map <String, SingleGroupSource<?>> getGroups() {
69+
return groups;
6270
}
6371

64-
@Override
65-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
66-
builder.startObject();
67-
builder.startObject(destinationFieldName);
68-
69-
builder.field(groupType.value(), groupSource);
70-
builder.endObject();
71-
builder.endObject();
72-
return builder;
72+
public boolean isValid() {
73+
return this.groups != null;
7374
}
7475

75-
public String getDestinationFieldName() {
76-
return destinationFieldName;
76+
@Override
77+
public void writeTo(StreamOutput out) throws IOException {
78+
out.writeMap(source);
79+
out.writeMap(groups, StreamOutput::writeString, (stream, value) -> {
80+
stream.writeByte(value.getType().getId());
81+
value.writeTo(stream);
82+
});
7783
}
7884

79-
public SingleGroupSource<?> getGroupSource() {
80-
return groupSource;
85+
@Override
86+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
87+
return builder.map(source);
8188
}
8289

8390
@Override
@@ -92,19 +99,44 @@ public boolean equals(Object other) {
9299

93100
final GroupConfig that = (GroupConfig) other;
94101

95-
return Objects.equals(this.destinationFieldName, that.destinationFieldName) && Objects.equals(this.groupType, that.groupType)
96-
&& Objects.equals(this.groupSource, that.groupSource);
102+
return Objects.equals(this.source, that.source) && Objects.equals(this.groups, that.groups);
97103
}
98104

99105
@Override
100106
public int hashCode() {
101-
return Objects.hash(destinationFieldName, groupType, groupSource);
107+
return Objects.hash(source, groups);
102108
}
103109

104110
public static GroupConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException {
105-
String destinationFieldName;
106-
Type groupType;
107-
SingleGroupSource<?> groupSource;
111+
NamedXContentRegistry registry = parser.getXContentRegistry();
112+
Map<String, Object> source = parser.mapOrdered();
113+
Map<String, SingleGroupSource<?>> groups = null;
114+
115+
if (source.isEmpty()) {
116+
if (lenient) {
117+
logger.warn(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY);
118+
} else {
119+
throw new IllegalArgumentException(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY);
120+
}
121+
} else {
122+
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
123+
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
124+
BytesReference.bytes(xContentBuilder).streamInput())) {
125+
groups = parseGroupConfig(sourceParser, lenient);
126+
} catch (Exception e) {
127+
if (lenient) {
128+
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_GROUP_BY, e);
129+
} else {
130+
throw e;
131+
}
132+
}
133+
}
134+
return new GroupConfig(source, groups);
135+
}
136+
137+
private static Map<String, SingleGroupSource<?>> parseGroupConfig(final XContentParser parser,
138+
boolean lenient) throws IOException {
139+
LinkedHashMap<String, SingleGroupSource<?>> groups = new LinkedHashMap<>();
108140

109141
// be parsing friendly, whether the token needs to be advanced or not (similar to what ObjectParser does)
110142
XContentParser.Token token;
@@ -116,19 +148,21 @@ public static GroupConfig fromXContent(final XContentParser parser, boolean leni
116148
throw new ParsingException(parser.getTokenLocation(), "Failed to parse object: Expected START_OBJECT but was: " + token);
117149
}
118150
}
119-
token = parser.nextToken();
120-
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
121-
destinationFieldName = parser.currentName();
122-
token = parser.nextToken();
123-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
124-
token = parser.nextToken();
125-
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
126-
groupType = SingleGroupSource.Type.valueOf(parser.currentName().toUpperCase(Locale.ROOT));
127-
128-
token = parser.nextToken();
129-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
130-
131-
switch (groupType) {
151+
152+
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
153+
154+
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
155+
String destinationFieldName = parser.currentName();
156+
token = parser.nextToken();
157+
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
158+
token = parser.nextToken();
159+
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
160+
Type groupType = SingleGroupSource.Type.valueOf(parser.currentName().toUpperCase(Locale.ROOT));
161+
162+
token = parser.nextToken();
163+
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
164+
SingleGroupSource<?> groupSource;
165+
switch (groupType) {
132166
case TERMS:
133167
groupSource = TermsGroupSource.fromXContent(parser, lenient);
134168
break;
@@ -140,11 +174,12 @@ public static GroupConfig fromXContent(final XContentParser parser, boolean leni
140174
break;
141175
default:
142176
throw new ParsingException(parser.getTokenLocation(), "invalid grouping type: " + groupType);
143-
}
177+
}
144178

145-
parser.nextToken();
146-
parser.nextToken();
179+
parser.nextToken();
147180

148-
return new GroupConfig(destinationFieldName, groupType, groupSource);
181+
groups.put(destinationFieldName, groupSource);
182+
}
183+
return groups;
149184
}
150185
}

0 commit comments

Comments
 (0)