[ML-DataFrame] remove array arguments for group_by #38895

Merged
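
In short: the pivot `group_by` section is no longer an array of single-entry objects but a single object keyed by destination field name, the same shape the `aggregations` section uses. From the test fixtures below:

```diff
-  "group_by": [ { "reviewer": { "terms": { "field": "user_id" } } } ],
+  "group_by": { "reviewer": { "terms": { "field": "user_id" } } },
```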
@@ -14,9 +14,12 @@
public final class DataFrameField {

// common parse fields
public static final ParseField AGGREGATIONS = new ParseField("aggregations");
public static final ParseField AGGS = new ParseField("aggs");
public static final ParseField ID = new ParseField("id");
public static final ParseField TRANSFORMS = new ParseField("transforms");
public static final ParseField COUNT = new ParseField("count");
public static final ParseField GROUP_BY = new ParseField("group_by");
public static final ParseField TIMEOUT = new ParseField("timeout");
public static final ParseField WAIT_FOR_COMPLETION = new ParseField("wait_for_completion");
public static final ParseField STATS_FIELD = new ParseField("stats");
@@ -33,13 +33,19 @@ public class DataFrameMessages {
"Failed to parse transform configuration for data frame transform [{0}]";
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_NO_TRANSFORM =
"Data frame transform configuration must specify exactly 1 function";
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY =
"Data frame pivot transform configuration must specify at least 1 group_by";
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION =
"Data frame pivot transform configuration must specify at least 1 aggregation";
public static final String DATA_FRAME_TRANSFORM_PIVOT_FAILED_TO_CREATE_COMPOSITE_AGGREGATION =
"Failed to create composite aggregation from pivot function";
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_INVALID =
"Data frame transform configuration [{0}] has invalid elements";

public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_QUERY =
"Failed to parse query for data frame transform";
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_GROUP_BY =
"Failed to parse group_by for data frame pivot transform";
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION =
"Failed to parse aggregation for data frame pivot transform";

@@ -87,11 +87,11 @@ public void testHistogramPivot() throws Exception {


config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"group_by\": {"
+ " \"every_2\": {"
+ " \"histogram\": {"
+ " \"interval\": 2,\"field\":\"stars\""
+ " } } } ],"
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
@@ -125,11 +125,11 @@ public void testBiggerPivot() throws Exception {


config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } } ],"
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
@@ -199,11 +199,11 @@ public void testDateHistogramPivot() throws Exception {


config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"group_by\": {"
+ " \"by_day\": {"
+ " \"date_histogram\": {"
+ " \"interval\": \"1d\",\"field\":\"timestamp\",\"format\":\"yyyy-MM-DD\""
+ " } } } ],"
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
@@ -125,11 +125,11 @@ protected void createPivotReviewsTransform(String transformId, String dataFrameI
}

config += " \"pivot\": {"
+ " \"group_by\": [ {"
+ " \"group_by\": {"
+ " \"reviewer\": {"
+ " \"terms\": {"
+ " \"field\": \"user_id\""
+ " } } } ],"
+ " } } },"
+ " \"aggregations\": {"
+ " \"avg_rating\": {"
+ " \"avg\": {"
@@ -70,19 +70,27 @@ public static AggregationConfig fromXContent(final XContentParser parser, boolean lenient)
NamedXContentRegistry registry = parser.getXContentRegistry();
Map<String, Object> source = parser.mapOrdered();
AggregatorFactories.Builder aggregations = null;
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
BytesReference.bytes(xContentBuilder).streamInput())) {
sourceParser.nextToken();
aggregations = AggregatorFactories.parseAggregators(sourceParser);
} catch (Exception e) {

if (source.isEmpty()) {
if (lenient) {
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION, e);
logger.warn(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION);
} else {
throw e;
throw new IllegalArgumentException(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_AGGREGATION);
}
} else {
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
BytesReference.bytes(xContentBuilder).streamInput())) {
sourceParser.nextToken();
aggregations = AggregatorFactories.parseAggregators(sourceParser);
} catch (Exception e) {
if (lenient) {
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_AGGREGATION, e);
} else {
throw e;
}
}
}

return new AggregationConfig(source, aggregations);
}
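
A minimal standalone sketch of the strict-versus-lenient pattern used above, with plain JDK types standing in for the Elasticsearch parser machinery (the class and method names here are made up for illustration):

```java
import java.util.Map;

// Sketch only: `source` stands in for the parsed "aggregations" object; the
// real code re-serializes it and feeds AggregatorFactories.parseAggregators.
final class LenientParseSketch {

    static Map<String, Object> parseAggregations(Map<String, Object> source, boolean lenient) {
        if (source.isEmpty()) {
            if (lenient) {
                // lenient mode: log a warning and carry on with null aggregations
                System.err.println("pivot transform configuration must specify at least 1 aggregation");
                return null;
            }
            // strict mode: reject the configuration outright
            throw new IllegalArgumentException(
                    "Data frame pivot transform configuration must specify at least 1 aggregation");
        }
        return source;
    }

    public static void main(String[] args) {
        System.out.println(parseAggregations(Map.of(), true)); // warns, prints null
        System.out.println(parseAggregations(
                Map.of("avg_rating", Map.of("avg", Map.of("field", "stars"))), false));
    }
}
```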

@@ -27,22 +27,22 @@ final class AggregationResultUtils {
* Extracts aggregation results from a composite aggregation and puts it into a map.
*
* @param agg The aggregation result
* @param sources The original sources used for querying
* @param groups The original groupings used for querying
* @param aggregationBuilders the aggregation used for querying
* @param dataFrameIndexerTransformStats stats collector
* @return a map containing the results of the aggregation in a consumable way
*/
public static Stream<Map<String, Object>> extractCompositeAggregationResults(CompositeAggregation agg,
Iterable<GroupConfig> sources, Collection<AggregationBuilder> aggregationBuilders,
GroupConfig groups, Collection<AggregationBuilder> aggregationBuilders,
DataFrameIndexerTransformStats dataFrameIndexerTransformStats) {
return agg.getBuckets().stream().map(bucket -> {
dataFrameIndexerTransformStats.incrementNumDocuments(bucket.getDocCount());

Map<String, Object> document = new HashMap<>();
for (GroupConfig source : sources) {
String destinationFieldName = source.getDestinationFieldName();
groups.getGroups().keySet().forEach(destinationFieldName -> {
document.put(destinationFieldName, bucket.getKey().get(destinationFieldName));
}
});

for (AggregationBuilder aggregationBuilder : aggregationBuilders) {
String aggName = aggregationBuilder.getName();

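Put differently, the extractor no longer loops over a list of GroupConfig objects; it takes the single GroupConfig and copies one composite bucket-key entry per destination field name. A tiny sketch with plain JDK maps (field names and values invented for illustration):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Sketch of the key-copy step: for each group_by destination field,
// take the matching entry out of the composite bucket's key map.
public class BucketKeySketch {
    public static void main(String[] args) {
        Map<String, Object> bucketKey = Map.of("reviewer", 42, "by_day", "2019-02-14");
        Set<String> destinationFieldNames = Set.of("reviewer", "by_day");

        Map<String, Object> document = new HashMap<>();
        destinationFieldNames.forEach(name -> document.put(name, bucketKey.get(name)));
        System.out.println(document); // reviewer=42 and by_day=2019-02-14 (iteration order unspecified)
    }
}
```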
@@ -83,6 +83,11 @@ public static DateHistogramGroupSource fromXContent(final XContentParser parser, boolean lenient)
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}

@Override
public Type getType() {
return Type.DATE_HISTOGRAM;
}

public long getInterval() {
return interval;
}
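
The getType() override added here is what the reworked GroupConfig (below) relies on to prefix each serialized group source with a one-byte type tag and to dispatch on that tag when reading. A toy round-trip of the same pattern, with DataOutput/DataInput standing in for StreamOutput/StreamInput and illustrative id values:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Toy sketch of a type-tagged wire format; the ids are illustrative,
// not the values used by SingleGroupSource.Type.
enum GroupType {
    TERMS((byte) 0), HISTOGRAM((byte) 1), DATE_HISTOGRAM((byte) 2);

    final byte id;
    GroupType(byte id) { this.id = id; }

    static GroupType fromId(byte id) throws IOException {
        for (GroupType t : values()) {
            if (t.id == id) { return t; }
        }
        throw new IOException("Unknown group type");
    }
}

class WireSketch {
    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        new DataOutputStream(buf).writeByte(GroupType.DATE_HISTOGRAM.id); // write the tag
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray()));
        System.out.println(GroupType.fromId(in.readByte())); // prints DATE_HISTOGRAM
    }
}
```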
@@ -6,17 +6,29 @@

package org.elasticsearch.xpack.dataframe.transforms.pivot;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.dataframe.DataFrameField;
import org.elasticsearch.xpack.core.dataframe.DataFrameMessages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.dataframe.transforms.pivot.SingleGroupSource.Type;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -26,58 +38,53 @@
*/
public class GroupConfig implements Writeable, ToXContentObject {

private final String destinationFieldName;
private final SingleGroupSource.Type groupType;
private final SingleGroupSource<?> groupSource;
private static final Logger logger = LogManager.getLogger(GroupConfig.class);

public GroupConfig(final String destinationFieldName, final SingleGroupSource.Type groupType, final SingleGroupSource<?> groupSource) {
this.destinationFieldName = Objects.requireNonNull(destinationFieldName);
this.groupType = Objects.requireNonNull(groupType);
this.groupSource = Objects.requireNonNull(groupSource);
private final Map<String, Object> source;
private final Map<String, SingleGroupSource<?>> groups;

public GroupConfig(final Map<String, Object> source, final Map<String, SingleGroupSource<?>> groups) {
this.source = ExceptionsHelper.requireNonNull(source, DataFrameField.GROUP_BY.getPreferredName());
this.groups = groups;
}

public GroupConfig(StreamInput in) throws IOException {
destinationFieldName = in.readString();
groupType = Type.fromId(in.readByte());
switch (groupType) {
case TERMS:
groupSource = in.readOptionalWriteable(TermsGroupSource::new);
break;
case HISTOGRAM:
groupSource = in.readOptionalWriteable(HistogramGroupSource::new);
break;
case DATE_HISTOGRAM:
groupSource = in.readOptionalWriteable(DateHistogramGroupSource::new);
break;
default:
throw new IOException("Unknown group type");
}
source = in.readMap();
groups = in.readMap(StreamInput::readString, (stream) -> {
Type groupType = Type.fromId(stream.readByte());
switch (groupType) {
case TERMS:
return new TermsGroupSource(stream);
case HISTOGRAM:
return new HistogramGroupSource(stream);
case DATE_HISTOGRAM:
return new DateHistogramGroupSource(stream);
default:
throw new IOException("Unknown group type");
}
});
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(destinationFieldName);
out.writeByte(groupType.getId());
out.writeOptionalWriteable(groupSource);
public Map <String, SingleGroupSource<?>> getGroups() {
return groups;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startObject(destinationFieldName);

builder.field(groupType.value(), groupSource);
builder.endObject();
builder.endObject();
return builder;
public boolean isValid() {
return this.groups != null;
}

public String getDestinationFieldName() {
return destinationFieldName;
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeMap(source);
out.writeMap(groups, StreamOutput::writeString, (stream, value) -> {
stream.writeByte(value.getType().getId());
value.writeTo(stream);
});
}

public SingleGroupSource<?> getGroupSource() {
return groupSource;
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(source);
}

@Override
@@ -92,19 +99,44 @@ public boolean equals(Object other) {

final GroupConfig that = (GroupConfig) other;

return Objects.equals(this.destinationFieldName, that.destinationFieldName) && Objects.equals(this.groupType, that.groupType)
&& Objects.equals(this.groupSource, that.groupSource);
return Objects.equals(this.source, that.source) && Objects.equals(this.groups, that.groups);
}

@Override
public int hashCode() {
return Objects.hash(destinationFieldName, groupType, groupSource);
return Objects.hash(source, groups);
}

public static GroupConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException {
String destinationFieldName;
Type groupType;
SingleGroupSource<?> groupSource;
NamedXContentRegistry registry = parser.getXContentRegistry();
Map<String, Object> source = parser.mapOrdered();
Map<String, SingleGroupSource<?>> groups = null;

if (source.isEmpty()) {
if (lenient) {
logger.warn(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY);
} else {
throw new IllegalArgumentException(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY);
}
} else {
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(source);
XContentParser sourceParser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE,
BytesReference.bytes(xContentBuilder).streamInput())) {
groups = parseGroupConfig(sourceParser, lenient);
} catch (Exception e) {
if (lenient) {
logger.warn(DataFrameMessages.LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_GROUP_BY, e);
} else {
throw e;
}
}
}
return new GroupConfig(source, groups);
}

private static Map<String, SingleGroupSource<?>> parseGroupConfig(final XContentParser parser,
boolean lenient) throws IOException {
LinkedHashMap<String, SingleGroupSource<?>> groups = new LinkedHashMap<>();

// be parsing friendly, whether the token needs to be advanced or not (similar to what ObjectParser does)
XContentParser.Token token;
@@ -116,19 +148,21 @@ public static GroupConfig fromXContent(final XContentParser parser, boolean lenient)
throw new ParsingException(parser.getTokenLocation(), "Failed to parse object: Expected START_OBJECT but was: " + token);
}
}
token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
destinationFieldName = parser.currentName();
token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
groupType = SingleGroupSource.Type.valueOf(parser.currentName().toUpperCase(Locale.ROOT));

token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);

switch (groupType) {

while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {

ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
String destinationFieldName = parser.currentName();
token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser::getTokenLocation);
Type groupType = SingleGroupSource.Type.valueOf(parser.currentName().toUpperCase(Locale.ROOT));

token = parser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser::getTokenLocation);
SingleGroupSource<?> groupSource;
switch (groupType) {
case TERMS:
groupSource = TermsGroupSource.fromXContent(parser, lenient);
break;
@@ -140,11 +174,12 @@ public static GroupConfig fromXContent(final XContentParser parser, boolean lenient)
break;
default:
throw new ParsingException(parser.getTokenLocation(), "invalid grouping type: " + groupType);
}
}

parser.nextToken();
parser.nextToken();
parser.nextToken();

return new GroupConfig(destinationFieldName, groupType, groupSource);
groups.put(destinationFieldName, groupSource);
}
return groups;
}
}
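
To make the new parsing loop concrete: group_by is now an ordered map from destination field name to exactly one grouping type plus its parameters. A self-contained sketch of that shape in plain JDK types (GroupSource below is a stand-in, not the real SingleGroupSource):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Stand-in for SingleGroupSource: one grouping type plus its parameters.
public class GroupByShapeSketch {

    static final class GroupSource {
        final String type;                // terms | histogram | date_histogram
        final Map<String, Object> params; // e.g. {field=user_id}

        GroupSource(String type, Map<String, Object> params) {
            this.type = type;
            this.params = params;
        }

        @Override
        public String toString() {
            return type + params;
        }
    }

    @SuppressWarnings("unchecked")
    static Map<String, GroupSource> parse(Map<String, Map<String, Object>> groupBy) {
        LinkedHashMap<String, GroupSource> groups = new LinkedHashMap<>();
        // one entry per destination field name; each value holds a single grouping
        groupBy.forEach((destinationFieldName, sourceObject) -> {
            Map.Entry<String, Object> grouping = sourceObject.entrySet().iterator().next();
            groups.put(destinationFieldName,
                    new GroupSource(grouping.getKey(), (Map<String, Object>) grouping.getValue()));
        });
        return groups;
    }

    public static void main(String[] args) {
        Map<String, Map<String, Object>> groupBy = new LinkedHashMap<>();
        groupBy.put("reviewer", Map.of("terms", Map.of("field", "user_id")));
        groupBy.put("by_day", Map.of("date_histogram", Map.of("interval", "1d", "field", "timestamp")));
        System.out.println(parse(groupBy)); // {reviewer=terms{field=user_id}, by_day=date_histogram{...}}
    }
}
```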