[ML] adds support for non-numeric mapped types #40220

Merged
@@ -45,7 +45,7 @@ public class DataFrameMessages {
"Failed to create composite aggregation from pivot function";
public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_INVALID =
"Data frame transform configuration [{0}] has invalid elements";

public static final String DATA_FRAME_UNABLE_TO_GATHER_FIELD_MAPPINGS = "Failed to gather field mappings for index [{0}]";
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_QUERY =
"Failed to parse query for data frame transform";
public static final String LOG_DATA_FRAME_TRANSFORM_CONFIGURATION_BAD_GROUP_BY =
@@ -18,6 +18,7 @@
import java.util.Set;

import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class DataFramePivotRestIT extends DataFrameRestTestCase {
@@ -267,6 +268,52 @@ public void testPreviewTransform() throws Exception {
});
}

public void testPivotWithMaxOnDateField() throws Exception {
String transformId = "simpleDateHistogramPivotWithMaxTime";
String dataFrameIndex = "pivot_reviews_via_date_histogram_with_max_time";
setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, dataFrameIndex);

final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId,
BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

String config = "{"
+ " \"source\": \"" + REVIEWS_INDEX_NAME + "\","
+ " \"dest\": \"" + dataFrameIndex + "\",";

config += " \"pivot\": { \n" +
" \"group_by\": {\n" +
" \"by_day\": {\"date_histogram\": {\n" +
" \"interval\": \"1d\",\"field\":\"timestamp\",\"format\":\"yyyy-MM-DD\"\n" +
" }}\n" +
" },\n" +
" \n" +
" \"aggs\" :{\n" +
" \"avg_rating\": {\n" +
" \"avg\": {\"field\": \"stars\"}\n" +
" },\n" +
" \"timestamp\": {\n" +
" \"max\": {\"field\": \"timestamp\"}\n" +
" }\n" +
" }}"
+ "}";

createDataframeTransformRequest.setJsonEntity(config);
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
assertTrue(indexExists(dataFrameIndex));

startAndWaitForTransform(transformId, dataFrameIndex, BASIC_AUTH_VALUE_DATA_FRAME_ADMIN_WITH_SOME_DATA_ACCESS);

// we expect 21 documents as there are 21 days' worth of docs
Map<String, Object> indexStats = getAsMap(dataFrameIndex + "/_stats");
assertEquals(21, XContentMapValues.extractValue("_all.total.docs.count", indexStats));
assertOnePivotValue(dataFrameIndex + "/_search?q=by_day:2017-01-15", 3.82);
Map<String, Object> searchResult = getAsMap(dataFrameIndex + "/_search?q=by_day:2017-01-15");
String actual = (String) ((List<?>) XContentMapValues.extractValue("hits.hits._source.timestamp", searchResult)).get(0);
// Use `containsString` as the exact ending timestamp is indeterminate due to how the data is generated
assertThat(actual, containsString("2017-01-15T20:"));
}

private void assertOnePivotValue(String query, double expected) throws IOException {
Map<String, Object> searchResult = getAsMap(query);

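For readability, the config string assembled in testPivotWithMaxOnDateField above is equivalent to the following JSON (whitespace added; <REVIEWS_INDEX_NAME> stands in for the test constant, whose value is defined in the test base class):

    {
      "source": "<REVIEWS_INDEX_NAME>",
      "dest": "pivot_reviews_via_date_histogram_with_max_time",
      "pivot": {
        "group_by": {
          "by_day": {
            "date_histogram": { "interval": "1d", "field": "timestamp", "format": "yyyy-MM-DD" }
          }
        },
        "aggs": {
          "avg_rating": { "avg": { "field": "stars" } },
          "timestamp": { "max": { "field": "timestamp" } }
        }
      }
    }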
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
import org.elasticsearch.xpack.dataframe.transforms.pivot.Pivot;

import java.util.List;
@@ -57,9 +58,11 @@ protected void doExecute(Task task,
return;
}

Pivot pivot = new Pivot(request.getConfig().getSource(),
request.getConfig().getQueryConfig().getQuery(),
request.getConfig().getPivotConfig());
final DataFrameTransformConfig config = request.getConfig();

Pivot pivot = new Pivot(config.getSource(),
config.getQueryConfig().getQuery(),
config.getPivotConfig());

getPreview(pivot, ActionListener.wrap(
previewResponse -> listener.onResponse(new PreviewDataFrameTransformAction.Response(previewResponse)),
@@ -68,18 +71,24 @@ }
}

private void getPreview(Pivot pivot, ActionListener<List<Map<String, Object>>> listener) {
ClientHelper.executeWithHeadersAsync(threadPool.getThreadContext().getHeaders(),
ClientHelper.DATA_FRAME_ORIGIN,
client,
SearchAction.INSTANCE,
pivot.buildSearchRequest(null),
ActionListener.wrap(
r -> {
final CompositeAggregation agg = r.getAggregations().get(COMPOSITE_AGGREGATION_NAME);
DataFrameIndexerTransformStats stats = new DataFrameIndexerTransformStats();
listener.onResponse(pivot.extractResults(agg, stats).collect(Collectors.toList()));
},
listener::onFailure
));
pivot.deduceMappings(client, ActionListener.wrap(
deducedMappings -> {
ClientHelper.executeWithHeadersAsync(threadPool.getThreadContext().getHeaders(),
ClientHelper.DATA_FRAME_ORIGIN,
client,
SearchAction.INSTANCE,
pivot.buildSearchRequest(null),
ActionListener.wrap(
r -> {
final CompositeAggregation agg = r.getAggregations().get(COMPOSITE_AGGREGATION_NAME);
DataFrameIndexerTransformStats stats = new DataFrameIndexerTransformStats();
listener.onResponse(pivot.extractResults(agg, deducedMappings, stats).collect(Collectors.toList()));
},
listener::onFailure
));
},
listener::onFailure
));

}
}
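The preview path now chains two async steps: deduce the destination mappings, then run the search, with a failure at either step short-circuiting to the caller's listener. A minimal, self-contained sketch of that chaining pattern, using hypothetical stand-ins for the ActionListener plumbing:

    import java.util.Collections;
    import java.util.Map;

    // Sketch of a two-step async chain: resolve mappings first, then search,
    // propagating any failure straight to the outer listener.
    class PreviewChainSketch {
        interface Listener<T> {
            void onResponse(T result);
            void onFailure(Exception e);
        }

        static void deduceMappings(Listener<Map<String, String>> listener) {
            // pretend this is an async mappings lookup
            listener.onResponse(Collections.singletonMap("timestamp", "date"));
        }

        static void search(Map<String, String> mappings, Listener<String> listener) {
            // pretend this is the composite-agg search, typed against the mappings
            listener.onResponse("results extracted with types " + mappings);
        }

        static void preview(final Listener<String> outer) {
            deduceMappings(new Listener<Map<String, String>>() {
                public void onResponse(Map<String, String> mappings) {
                    search(mappings, outer);   // step two reuses the outer listener
                }
                public void onFailure(Exception e) {
                    outer.onFailure(e);        // failure in step one skips the search
                }
            });
        }

        public static void main(String[] args) {
            preview(new Listener<String>() {
                public void onResponse(String r) { System.out.println(r); }
                public void onFailure(Exception e) { e.printStackTrace(); }
            });
        }
    }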
@@ -44,6 +44,8 @@ public DataFrameIndexer(Executor executor, AtomicReference<IndexerState> initial

protected abstract DataFrameTransformConfig getConfig();

protected abstract Map<String, String> getFieldMappings();

@Override
protected void onStartJob(long now) {
QueryBuilder queryBuilder = getConfig().getQueryConfig().getQuery();
@@ -70,7 +72,7 @@ private Stream<IndexRequest> processBucketsToIndexRequests(CompositeAggregation
final DataFrameTransformConfig transformConfig = getConfig();
String indexName = transformConfig.getDestination();

return pivot.extractResults(agg, getStats()).map(document -> {
return pivot.extractResults(agg, getFieldMappings(), getStats()).map(document -> {
XContentBuilder builder;
try {
builder = jsonBuilder();
@@ -36,6 +36,7 @@
import org.elasticsearch.xpack.core.scheduler.SchedulerEngine.Event;
import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpointService;
import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
import org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil;

import java.util.Map;
import java.util.concurrent.CountDownLatch;
@@ -230,6 +231,7 @@ protected class ClientDataFrameIndexer extends DataFrameIndexer {
private final DataFrameTransformsConfigManager transformsConfigManager;
private final DataFrameTransformsCheckpointService transformsCheckpointService;
private final String transformId;
private Map<String, String> fieldMappings = null;

private DataFrameTransformConfig transformConfig = null;

@@ -248,6 +250,11 @@ protected DataFrameTransformConfig getConfig() {
return transformConfig;
}

@Override
protected Map<String, String> getFieldMappings() {
return fieldMappings;
}

@Override
protected String getJobId() {
return transformId;
@@ -279,6 +286,27 @@ public synchronized boolean maybeTriggerAsyncJob(long now) {
DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_TRANSFORM_CONFIGURATION_INVALID, transformId));
}

if (fieldMappings == null) {
CountDownLatch latch = new CountDownLatch(1);
SchemaUtil.getDestinationFieldMappings(client, transformConfig.getDestination(), new LatchedActionListener<>(
ActionListener.wrap(
destinationMappings -> fieldMappings = destinationMappings,
e -> {
throw new RuntimeException(
DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_UNABLE_TO_GATHER_FIELD_MAPPINGS,
transformConfig.getDestination()),
e);
}), latch));
try {
latch.await(LOAD_TRANSFORM_TIMEOUT_IN_SECONDS, TimeUnit.SECONDS);
Review comment (Member):
TIL: I thought fieldMappings would have to be volatile for the assignment to be visible to the awaiting thread but latch.await is a synchronisation point. https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/package-summary.html#MemoryVisibility.

I presume we are blocking here because super.maybeTriggerAsyncJob(now); must execute on the calling thread.
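
A minimal, self-contained sketch of that visibility guarantee (hypothetical names; the write to the plain field happens-before countDown(), and a return from await() happens-after the countDown()):

    import java.util.concurrent.CountDownLatch;

    // Sketch: a plain (non-volatile) field written before countDown() is
    // guaranteed visible to a thread that has returned from await().
    class LatchVisibilitySketch {
        private String result; // deliberately not volatile

        String fetchBlocking() throws InterruptedException {
            CountDownLatch latch = new CountDownLatch(1);
            new Thread(() -> {
                result = "mappings";   // write happens-before countDown()
                latch.countDown();
            }).start();
            latch.await();             // synchronizes-with countDown()
            return result;             // guaranteed to observe "mappings"
        }

        public static void main(String[] args) throws InterruptedException {
            System.out.println(new LatchVisibilitySketch().fetchBlocking());
        }
    }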

} catch (InterruptedException e) {
throw new RuntimeException(
DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_UNABLE_TO_GATHER_FIELD_MAPPINGS,
transformConfig.getDestination()),
e);
}
}

return super.maybeTriggerAsyncJob(now);
}

@@ -21,6 +21,8 @@
import java.util.Map;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil.isNumericType;

final class AggregationResultUtils {
private static final Logger logger = LogManager.getLogger(AggregationResultUtils.class);

@@ -30,30 +32,38 @@ final class AggregationResultUtils {
* @param agg The aggregation result
* @param groups The original groupings used for querying
* @param aggregationBuilders the aggregation used for querying
* @param dataFrameIndexerTransformStats stats collector
* @param fieldTypeMap A Map containing "field-name": "type" entries to determine the appropriate type for the aggregation results.
* @param stats stats collector
* @return a map containing the results of the aggregation in a consumable way
*/
public static Stream<Map<String, Object>> extractCompositeAggregationResults(CompositeAggregation agg,
GroupConfig groups,
Collection<AggregationBuilder> aggregationBuilders,
DataFrameIndexerTransformStats dataFrameIndexerTransformStats) {
GroupConfig groups,
Collection<AggregationBuilder> aggregationBuilders,
Map<String, String> fieldTypeMap,
DataFrameIndexerTransformStats stats) {
return agg.getBuckets().stream().map(bucket -> {
dataFrameIndexerTransformStats.incrementNumDocuments(bucket.getDocCount());
stats.incrementNumDocuments(bucket.getDocCount());

Map<String, Object> document = new HashMap<>();
groups.getGroups().keySet().forEach(destinationFieldName -> {
document.put(destinationFieldName, bucket.getKey().get(destinationFieldName));
});
groups.getGroups().keySet().forEach(destinationFieldName ->
document.put(destinationFieldName, bucket.getKey().get(destinationFieldName)));

for (AggregationBuilder aggregationBuilder : aggregationBuilders) {
String aggName = aggregationBuilder.getName();
final String fieldType = fieldTypeMap.get(aggName);

// TODO: support other aggregation types
Aggregation aggResult = bucket.getAggregations().get(aggName);

if (aggResult instanceof NumericMetricsAggregation.SingleValue) {
NumericMetricsAggregation.SingleValue aggResultSingleValue = (SingleValue) aggResult;
document.put(aggName, aggResultSingleValue.value());
// If the destination type is numeric, simply gather the numeric `value`; otherwise use
// `getValueAsString` so we don't lose formatted output.
if (isNumericType(fieldType)) {
document.put(aggName, aggResultSingleValue.value());
} else {
document.put(aggName, aggResultSingleValue.getValueAsString());
}
} else {
// Execution should never reach this point!
// Creating transforms with unsupported aggregations shall not be possible
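The branch above relies on the isNumericType check imported from SchemaUtil. A self-contained sketch of the dispatch (stand-in types, an illustrative subset of Elasticsearch's numeric mapping types, and made-up values; a max on a date field reports its result as epoch millis):

    import java.time.Instant;
    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    // Sketch: numeric destination fields keep the raw double,
    // anything else (e.g. date) keeps the formatted string.
    class ValueExtractionSketch {
        static final Set<String> NUMERIC_TYPES = new HashSet<>(Arrays.asList(
            "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float"));

        // stand-in for NumericMetricsAggregation.SingleValue
        interface SingleValueLike {
            double value();            // raw double; epoch millis for a max on a date field
            String getValueAsString(); // formatted per the field's mapping
        }

        static Object extract(String fieldType, SingleValueLike agg) {
            return NUMERIC_TYPES.contains(fieldType) ? agg.value() : agg.getValueAsString();
        }

        public static void main(String[] args) {
            SingleValueLike maxTimestamp = new SingleValueLike() {
                public double value() { return 1484512553000.0; }
                public String getValueAsString() {
                    return Instant.ofEpochMilli((long) value()).toString();
                }
            };
            System.out.println(extract("double", maxTimestamp)); // 1.484512553E12
            System.out.println(extract("date", maxTimestamp));   // 2017-01-15T20:35:53Z
        }
    }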
@@ -77,12 +77,17 @@ public SearchRequest buildSearchRequest(Map<String, Object> position) {
}

public Stream<Map<String, Object>> extractResults(CompositeAggregation agg,
DataFrameIndexerTransformStats dataFrameIndexerTransformStats) {
Map<String, String> fieldTypeMap,
DataFrameIndexerTransformStats dataFrameIndexerTransformStats) {

GroupConfig groups = config.getGroupConfig();
Collection<AggregationBuilder> aggregationBuilders = config.getAggregationConfig().getAggregatorFactories();

return AggregationResultUtils.extractCompositeAggregationResults(agg, groups, aggregationBuilders, dataFrameIndexerTransformStats);
return AggregationResultUtils.extractCompositeAggregationResults(agg,
groups,
aggregationBuilders,
fieldTypeMap,
dataFrameIndexerTransformStats);
}

private void runTestQuery(Client client, final ActionListener<Boolean> listener) {
@@ -99,7 +104,7 @@ private void runTestQuery(Client client, final ActionListener<Boolean> listener)
}
listener.onResponse(true);
}, e->{
listener.onFailure(new RuntimeException("Failed to test query",e));
listener.onFailure(new RuntimeException("Failed to test query", e));
}));
}

Expand Down
Loading