qpt change for elastic search (#1854)
Trianz-Akshay authored Mar 29, 2024
1 parent dec75b6 commit 4449052
Showing 7 changed files with 162 additions and 31 deletions.
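In short: this commit adds query passthrough (QPT) support to the Athena Elasticsearch connector. A new ElasticsearchQueryPassthrough class defines a system.query(SCHEMA, INDEX, QUERY) table-function signature; ElasticsearchMetadataHandler advertises the capability and resolves the passthrough schema and splits; ElasticsearchRecordHandler executes the raw query DSL via QueryBuilders.wrapperQuery.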
CacheableAwsRestHighLevelClient.java
@@ -98,10 +98,10 @@ public void put(String endpoint, AwsRestHighLevelClient client)
      */
     private void evictCache(boolean force)
     {
-        Iterator<Map.Entry<String, CacheableAwsRestHighLevelClient.CacheEntry>> itr = clientCache.entrySet().iterator();
+        Iterator<Map.Entry<String, CacheEntry>> itr = clientCache.entrySet().iterator();
         int removed = 0;
         while (itr.hasNext()) {
-            CacheableAwsRestHighLevelClient.CacheEntry entry = itr.next().getValue();
+            CacheEntry entry = itr.next().getValue();
             // If age of client is greater than the maximum allowed, remove it.
             if (entry.getAge() > MAX_CACHE_AGE_MS) {
                 closeClient(entry.getClient());
ElasticsearchFieldResolver.java
@@ -38,7 +38,7 @@

 /**
  * Used to resolve Elasticsearch complex structures to Apache Arrow Types.
- * @see com.amazonaws.athena.connector.lambda.data.FieldResolver
+ * @see FieldResolver
  */
 public class ElasticsearchFieldResolver
         implements FieldResolver
ElasticsearchMetadataHandler.java
@@ -26,6 +26,8 @@
 import com.amazonaws.athena.connector.lambda.domain.Split;
 import com.amazonaws.athena.connector.lambda.domain.TableName;
 import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
+import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
 import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
 import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -36,7 +38,9 @@
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
 import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
 import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
+import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
 import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
+import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough;
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.glue.AWSGlue;
 import com.amazonaws.services.secretsmanager.AWSSecretsManager;
@@ -110,8 +114,9 @@ public class ElasticsearchMetadataHandler
     private final ElasticsearchDomainMapProvider domainMapProvider;
 
     private ElasticsearchGlueTypeMapper glueTypeMapper;
+    private final ElasticsearchQueryPassthrough queryPassthrough = new ElasticsearchQueryPassthrough();
 
-    public ElasticsearchMetadataHandler(java.util.Map<String, String> configOptions)
+    public ElasticsearchMetadataHandler(Map<String, String> configOptions)
     {
         super(SOURCE_TYPE, configOptions);
         this.awsGlue = getAwsGlue();
@@ -134,7 +139,7 @@ protected ElasticsearchMetadataHandler(
             ElasticsearchDomainMapProvider domainMapProvider,
             AwsRestHighLevelClientFactory clientFactory,
             long queryTimeout,
-            java.util.Map<String, String> configOptions)
+            Map<String, String> configOptions)
     {
         super(awsGlue, keyFactory, awsSecretsManager, athena, SOURCE_TYPE, spillBucket, spillPrefix, configOptions);
         this.awsGlue = awsGlue;
@@ -241,15 +246,7 @@ public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest request)
         if (schema == null) {
             String index = request.getTableName().getTableName();
             String endpoint = getDomainEndpoint(request.getTableName().getSchemaName());
-            AwsRestHighLevelClient client = clientFactory.getOrCreateClient(endpoint);
-            try {
-                Map<String, Object> mappings = client.getMapping(index);
-                schema = ElasticsearchSchemaUtils.parseMapping(mappings);
-            }
-            catch (IOException error) {
-                throw new RuntimeException("Error retrieving mapping information for index (" +
-                        index + "): " + error.getMessage(), error);
-            }
+            schema = getSchema(index, endpoint);
         }
 
         return new GetTableResponse(request.getCatalogName(), request.getTableName(),
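(The mapping-retrieval logic removed here reappears below as the private getSchema helper, so the new doGetQueryPassthroughSchema method can reuse the same mapping-to-Arrow-schema path.)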
@@ -285,15 +282,23 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
             throws IOException
     {
         logger.debug("doGetSplits: enter - " + request);
 
+        String domain;
+        String indx;
         // Get domain
-        String domain = request.getTableName().getSchemaName();
+        if (request.getConstraints().isQueryPassThrough()) {
+            domain = request.getConstraints().getQueryPassthroughArguments().get(ElasticsearchQueryPassthrough.SCHEMA);
+            indx = request.getConstraints().getQueryPassthroughArguments().get(ElasticsearchQueryPassthrough.INDEX);
+        }
+        else {
+            domain = request.getTableName().getSchemaName();
+            indx = request.getTableName().getTableName();
+        }
 
         String endpoint = getDomainEndpoint(domain);
         AwsRestHighLevelClient client = clientFactory.getOrCreateClient(endpoint);
         // We send an index request in case the table name is a data stream; a data stream can contain multiple indices, which are created by ES.
         // For a non data stream, the index name is the same as the table name.
-        GetIndexResponse indexResponse = client.indices().get(new GetIndexRequest(request.getTableName().getTableName()), RequestOptions.DEFAULT);
+        GetIndexResponse indexResponse = client.indices().get(new GetIndexRequest(indx), RequestOptions.DEFAULT);
 
         Set<Split> splits = Arrays.stream(indexResponse.getIndices())
                 .flatMap(index -> getShardsIDsFromES(client, index) // get all shards for an index.
@@ -305,6 +310,46 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
         return new GetSplitsResponse(request.getCatalogName(), splits);
     }
 
+    @Override
+    public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
+    {
+        ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
+        queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
+
+        return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
+    }
+
+    @Override
+    public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
+    {
+        logger.debug("doGetQueryPassthroughSchema: enter - " + request);
+        if (!request.isQueryPassthrough()) {
+            throw new IllegalArgumentException("No query passed through: " + request);
+        }
+        queryPassthrough.verify(request.getQueryPassthroughArguments());
+        String index = request.getQueryPassthroughArguments().get(ElasticsearchQueryPassthrough.INDEX);
+        String endpoint = getDomainEndpoint(request.getQueryPassthroughArguments().get(ElasticsearchQueryPassthrough.SCHEMA));
+        Schema schema = getSchema(index, endpoint);
+
+        return new GetTableResponse(request.getCatalogName(), request.getTableName(),
+                (schema == null) ? SchemaBuilder.newBuilder().build() : schema, Collections.emptySet());
+    }
+
+    private Schema getSchema(String index, String endpoint)
+    {
+        Schema schema;
+        AwsRestHighLevelClient client = clientFactory.getOrCreateClient(endpoint);
+        try {
+            Map<String, Object> mappings = client.getMapping(index);
+            schema = ElasticsearchSchemaUtils.parseMapping(mappings);
+        }
+        catch (IOException error) {
+            throw new RuntimeException("Error retrieving mapping information for index (" + index + ")", error);
+        }
+        return schema;
+    }
 
     /**
      * Get all data streams from ES; one data stream can contain multiple indices, which start with ".ds-xxxxxxxxxx".
      * Returns empty if not supported.
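For a concrete picture of what the new passthrough branches consume, here is a minimal sketch of the argument map, keyed by the ElasticsearchQueryPassthrough constants. The class name and the domain, index, and query values are invented for illustration, not taken from the commit.

import com.google.common.collect.ImmutableMap;

import java.util.Map;

public class QptArgsExample
{
    public static void main(String[] args)
    {
        // Keys match ElasticsearchQueryPassthrough.SCHEMA / INDEX / QUERY.
        // Values are hypothetical: an ES domain that getDomainEndpoint() can
        // resolve, an index (or data stream) name, and raw query DSL as JSON.
        Map<String, String> qptArgs = ImmutableMap.of(
                "SCHEMA", "movies",
                "INDEX", "customer",
                "QUERY", "{\"match_all\": {}}");
        System.out.println(qptArgs);
    }
}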
ElasticsearchRecordHandler.java
@@ -26,6 +26,7 @@
 import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor;
 import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
 import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
+import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough;
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
 import com.amazonaws.services.s3.AmazonS3;
@@ -40,6 +41,8 @@
 import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.client.RequestOptions;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.Scroll;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -48,6 +51,7 @@

 import java.io.IOException;
 import java.util.Iterator;
+import java.util.Map;
 import java.util.concurrent.TimeUnit;

/**
@@ -83,8 +87,9 @@ public class ElasticsearchRecordHandler

     private final AwsRestHighLevelClientFactory clientFactory;
     private final ElasticsearchTypeUtils typeUtils;
+    private final ElasticsearchQueryPassthrough queryPassthrough = new ElasticsearchQueryPassthrough();
 
-    public ElasticsearchRecordHandler(java.util.Map<String, String> configOptions)
+    public ElasticsearchRecordHandler(Map<String, String> configOptions)
     {
         super(AmazonS3ClientBuilder.defaultClient(), AWSSecretsManagerClientBuilder.defaultClient(),
                 AmazonAthenaClientBuilder.defaultClient(), SOURCE_TYPE, configOptions);
@@ -103,7 +108,7 @@ protected ElasticsearchRecordHandler(
             AwsRestHighLevelClientFactory clientFactory,
             long queryTimeout,
             long scrollTimeout,
-            java.util.Map<String, String> configOptions)
+            Map<String, String> configOptions)
     {
         super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions);

@@ -133,14 +138,27 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest,
             QueryStatusChecker queryStatusChecker)
             throws RuntimeException
     {
-        logger.info("readWithConstraint - enter - Domain: {}, Index: {}, Mapping: {}",
-                recordsRequest.getTableName().getSchemaName(), recordsRequest.getTableName().getTableName(),
-                recordsRequest.getSchema());
+        String domain;
+        QueryBuilder query;
+        String index;
+        if (recordsRequest.getConstraints().isQueryPassThrough()) {
+            Map<String, String> qptArgs = recordsRequest.getConstraints().getQueryPassthroughArguments();
+            queryPassthrough.verify(qptArgs);
+            domain = qptArgs.get(ElasticsearchQueryPassthrough.SCHEMA);
+            index = qptArgs.get(ElasticsearchQueryPassthrough.INDEX);
+            query = QueryBuilders.wrapperQuery(qptArgs.get(ElasticsearchQueryPassthrough.QUERY));
+        }
+        else {
+            domain = recordsRequest.getTableName().getSchemaName();
+            index = recordsRequest.getSplit().getProperty(ElasticsearchMetadataHandler.INDEX_KEY);
+            query = ElasticsearchQueryUtils.getQuery(recordsRequest.getConstraints());
+        }
 
-        String domain = recordsRequest.getTableName().getSchemaName();
         String endpoint = recordsRequest.getSplit().getProperty(domain);
         String shard = recordsRequest.getSplit().getProperty(ElasticsearchMetadataHandler.SHARD_KEY);
-        String index = recordsRequest.getSplit().getProperty(ElasticsearchMetadataHandler.INDEX_KEY);
+        logger.info("readWithConstraint - enter - Domain: {}, Index: {}, Mapping: {}, Query: {}",
+                domain, index,
+                recordsRequest.getSchema(), query);
         long numRows = 0;
 
         if (queryStatusChecker.isQueryRunning()) {
@@ -154,7 +172,7 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest,
                     .size(QUERY_BATCH_SIZE)
                     .timeout(new TimeValue(queryTimeout, TimeUnit.SECONDS))
                     .fetchSource(ElasticsearchQueryUtils.getProjection(recordsRequest.getSchema()))
-                    .query(ElasticsearchQueryUtils.getQuery(recordsRequest.getConstraints()));
+                    .query(query);
 
             //init scroll
             Scroll scroll = new Scroll(TimeValue.timeValueSeconds(this.scrollTimeout));
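The passthrough branch above relies on QueryBuilders.wrapperQuery, which embeds a raw query-DSL JSON string in a typed QueryBuilder, so the user-supplied query flows into the existing scroll/search pipeline without any parsing by the connector. A self-contained sketch of that pattern follows; the class name and the term query are invented for illustration.

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;

public class WrapperQueryExample
{
    public static void main(String[] args)
    {
        // wrapperQuery() takes query DSL as a JSON string and defers its
        // interpretation to the Elasticsearch cluster at search time.
        QueryBuilder query = QueryBuilders.wrapperQuery("{\"term\": {\"genre\": \"horror\"}}");
        SearchSourceBuilder source = new SearchSourceBuilder().query(query);
        System.out.println(source); // prints the JSON search body
    }
}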
ElasticsearchQueryPassthrough.java (new file)
@@ -0,0 +1,68 @@
/*-
* #%L
* athena-elasticsearch
* %%
* Copyright (C) 2019 - 2024 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.elasticsearch.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;

public class ElasticsearchQueryPassthrough implements QueryPassthroughSignature
{
private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchQueryPassthrough.class);
// Name of the passthrough table function.
public static final String NAME = "query";

// Schema (namespace) under which the function is registered.
public static final String SCHEMA_NAME = "system";

// Named arguments the function accepts.
public static final String SCHEMA = "SCHEMA";
public static final String INDEX = "INDEX";
public static final String QUERY = "QUERY";

// Statically initialized, since the argument list never changes.
public static final List<String> ARGUMENTS = Arrays.asList(SCHEMA, INDEX, QUERY);

@Override
public String getFunctionSchema()
{
return SCHEMA_NAME;
}

@Override
public String getFunctionName()
{
return NAME;
}

@Override
public List<String> getFunctionArguments()
{
return ARGUMENTS;
}

@Override
public Logger getLogger()
{
return LOGGER;
}
}
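The signature above registers a table function named query under the system schema, taking SCHEMA, INDEX, and QUERY arguments. By the usual Athena query passthrough convention this would be invoked from SQL as something like SELECT * FROM TABLE(system.query(SCHEMA => ..., INDEX => ..., QUERY => ...)); the invocation syntax is assumed here, not shown in this commit. A quick sketch exercising the class as written (the class name is invented for illustration):

import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough;

public class SignatureExample
{
    public static void main(String[] args)
    {
        ElasticsearchQueryPassthrough qpt = new ElasticsearchQueryPassthrough();
        System.out.println(qpt.getFunctionSchema());    // "system"
        System.out.println(qpt.getFunctionName());      // "query"
        System.out.println(qpt.getFunctionArguments()); // [SCHEMA, INDEX, QUERY]
    }
}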
ElasticsearchMetadataHandlerTest.java
@@ -141,7 +141,7 @@ public void doListSchemaNames()
"domain2", "endpoint2","domain3", "endpoint3"));

handler = new ElasticsearchMetadataHandler(awsGlue, new LocalKeyFactory(), awsSecretsManager, amazonAthena,
"spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, com.google.common.collect.ImmutableMap.of());
"spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, ImmutableMap.of());

ListSchemasRequest req = new ListSchemasRequest(fakeIdentity(), "queryId", "elasticsearch");
ListSchemasResponse realDomains = handler.doListSchemaNames(allocator, req);
@@ -205,7 +205,7 @@ public void doListTables()
         when(domainMapProvider.getDomainMap(null)).thenReturn(ImmutableMap.of("movies",
                 "https://search-movies-ne3fcqzfipy6jcrew2wca6kyqu.us-east-1.es.amazonaws.com"));
         handler = new ElasticsearchMetadataHandler(awsGlue, new LocalKeyFactory(), awsSecretsManager, amazonAthena,
-                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, com.google.common.collect.ImmutableMap.of());
+                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, ImmutableMap.of());
 
         IndicesClient indices = mock(IndicesClient.class);
         GetDataStreamResponse mockIndexResponse = mock(GetDataStreamResponse.class);
@@ -385,7 +385,7 @@ public void doGetTable()
         when(domainMapProvider.getDomainMap(null)).thenReturn(ImmutableMap.of("movies",
                 "https://search-movies-ne3fcqzfipy6jcrew2wca6kyqu.us-east-1.es.amazonaws.com"));
         handler = new ElasticsearchMetadataHandler(awsGlue, new LocalKeyFactory(), awsSecretsManager, amazonAthena,
-                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, com.google.common.collect.ImmutableMap.of());
+                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, ImmutableMap.of());
         GetTableRequest req = new GetTableRequest(fakeIdentity(), "queryId", "elasticsearch",
                 new TableName("movies", "mishmash"), Collections.emptyMap());
         GetTableResponse res = handler.doGetTable(allocator, req);
@@ -446,7 +446,7 @@ public void doGetSplits()

         // Instantiate handler
         handler = new ElasticsearchMetadataHandler(awsGlue, new LocalKeyFactory(), awsSecretsManager, amazonAthena,
-                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, com.google.common.collect.ImmutableMap.of());
+                "spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, ImmutableMap.of());
 
         // Call doGetSplits()
         MetadataResponse rawResponse = handler.doGetSplits(allocator, req);
@@ -493,7 +493,7 @@ public void convertFieldTest()
logger.info("convertFieldTest: enter");

handler = new ElasticsearchMetadataHandler(awsGlue, new LocalKeyFactory(), awsSecretsManager, amazonAthena,
"spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, com.google.common.collect.ImmutableMap.of());
"spill-bucket", "spill-prefix", domainMapProvider, clientFactory, 10, ImmutableMap.of());

Field field = handler.convertField("myscaled", "SCALED_FLOAT(10.51)");

ElasticsearchRecordHandlerTest.java
@@ -317,7 +317,7 @@ public void setUp()
         when(mockScrollResponse.getHits()).thenReturn(null);
         when(mockClient.scroll(any(), any())).thenReturn(mockScrollResponse);
 
-        handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720, 60, com.google.common.collect.ImmutableMap.of());
+        handler = new ElasticsearchRecordHandler(amazonS3, awsSecretsManager, athena, clientFactory, 720, 60, ImmutableMap.of());
 
         logger.info("setUpBefore - exit");
     }
Expand Down
