Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Pass accountId to EMRServerlessClientFactory.getClient #2822

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Pass accountId to EMRServerlessClientFactory.getClient (#2783)
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
(cherry picked from commit e24b51f)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Jun 28, 2024
commit 5581203cca940c8d5f8e6104740dfd792fb7bfc3
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory {
/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @param accountId Account ID of the requester. It will be used to decide the cluster.
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
EMRServerlessClient getClient(String accountId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.metrics.MetricsService;

/** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

Expand All @@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
private EMRServerlessClient emrServerlessClient;
private String region;

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
@Override
public EMRServerlessClient getClient() {
public EMRServerlessClient getClient(String accountId) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullAsyncQueryRequestContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ public class QueryHandlerFactory {
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final MetricsService metricsService;

public RefreshQueryHandler getRefreshQueryHandler() {
public RefreshQueryHandler getRefreshQueryHandler(String accountId) {
return new RefreshQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
flintIndexMetadataService,
leaseManager,
flintIndexOpFactory,
metricsService);
}

public StreamingQueryHandler getStreamingQueryHandler() {
public StreamingQueryHandler getStreamingQueryHandler(String accountId) {
return new StreamingQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
}

public BatchQueryHandler getBatchQueryHandler() {
public BatchQueryHandler getBatchQueryHandler(String accountId) {
return new BatchQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch(
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
}
}

Expand All @@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchConte
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
IndexQueryDetails indexQueryDetails) {
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
// Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel
// an interactive job.
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// Manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId());
} else {
return getDefaultAsyncQueryHandler();
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId());
}
}

@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler();
: queryHandlerFactory.getBatchQueryHandler(accountId);
}

@NotNull
Expand Down Expand Up @@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
} else {
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public Session createSession(
.sessionId(sessionIdProvider.getSessionId(request))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId()))
.build();
session.open(request, asyncQueryRequestContext);
return session;
Expand Down Expand Up @@ -65,7 +65,7 @@ public Optional<Session> getSession(String sessionId, String dataSourceName) {
.sessionId(sessionId)
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId()))
.sessionModel(model.get())
.sessionInactivityTimeoutMilli(
sessionConfigSupplier.getSessionInactivityTimeoutMillis())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel)
throws InterruptedException, TimeoutException {
String applicationId = flintIndexStateModel.getApplicationId();
String jobId = flintIndexStateModel.getJobId();
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient =
emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId());
try {
emrServerlessClient.cancelJobRun(
flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
@ExtendWith(MockitoExtension.class)
public class EMRServerlessClientFactoryImplTest {

public static final String ACCOUNT_ID = "accountId";
@Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
@Mock private MetricsService metricsService;

Expand All @@ -30,7 +31,9 @@ public void testGetClient() {
.thenReturn(createSparkExecutionEngineConfig());
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();

EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);

Assertions.assertNotNull(emrserverlessClient);
}

Expand All @@ -41,16 +44,16 @@ public void testGetClientWithChangeInSetting() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotNull(emrserverlessClient);

EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertEquals(emrServerlessClient1, emrserverlessClient);

sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION);
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient);
Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1);
}
Expand All @@ -60,9 +63,11 @@ public void testGetClientWithException() {
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class, emrServerlessClientFactory::getClient);
IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));

Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",
Expand All @@ -77,9 +82,11 @@ public void testGetClientWithExceptionWithNullRegion() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class, emrServerlessClientFactory::getClient);
IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));

Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void setUp() {

@Test
void testDispatchSelectQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -179,7 +179,7 @@ void testDispatchSelectQuery() {

@Test
void testDispatchSelectQueryWithLakeFormation() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -220,7 +220,7 @@ void testDispatchSelectQueryWithLakeFormation() {

@Test
void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -262,7 +262,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {

@Test
void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -368,7 +368,7 @@ void testDispatchSelectQueryFailedCreateSession() {

@Test
void testDispatchCreateAutoRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -413,7 +413,7 @@ void testDispatchCreateAutoRefreshIndexQuery() {

@Test
void testDispatchCreateManualRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -456,7 +456,7 @@ void testDispatchCreateManualRefreshIndexQuery() {

@Test
void testDispatchWithPPLQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -499,7 +499,7 @@ void testDispatchWithPPLQuery() {

@Test
void testDispatchQueryWithoutATableAndDataSourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -540,7 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {

@Test
void testDispatchIndexQueryWithoutADatasourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -585,7 +585,7 @@ void testDispatchIndexQueryWithoutADatasourceName() {

@Test
void testDispatchMaterializedViewQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_mv_1");
Expand Down Expand Up @@ -630,7 +630,7 @@ void testDispatchMaterializedViewQuery() {

@Test
void testDispatchShowMVQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -671,7 +671,7 @@ void testDispatchShowMVQuery() {

@Test
void testRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -712,7 +712,7 @@ void testRefreshIndexQuery() {

@Test
void testDispatchDescribeIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -753,7 +753,7 @@ void testDispatchDescribeIndexQuery() {

@Test
void testDispatchAlterToAutoRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -906,7 +906,7 @@ void testDispatchWithUnSupportedDataSourceType() {

@Test
void testCancelJob() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
Expand Down Expand Up @@ -968,7 +968,7 @@ void testCancelQueryWithInvalidStatementId() {

@Test
void testCancelQueryWithNoSessionId() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
Expand All @@ -982,7 +982,7 @@ void testCancelQueryWithNoSessionId() {

@Test
void testGetQueryResponse() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));
// simulate result index is not created yet
Expand Down Expand Up @@ -1079,7 +1079,7 @@ void testGetQueryResponseWithSuccess() {

@Test
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);
Expand Down
Loading
Loading