diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 6ec67709b8..8feeddcafc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -213,7 +213,8 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); if (sessionManager.isEnabled()) { - Session session; + Session session = null; + if (dispatchQueryRequest.getSessionId() != null) { // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); @@ -222,8 +223,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ throw new IllegalArgumentException("no session found. " + sessionId); } session = createdSession.get(); - } else { - // create session if not exist + } + if (session == null || !session.isReady()) { + // create session if not exist or session dead/fail tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); session = sessionManager.createSession( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 956275b04a..3221b33b2c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,7 +6,9 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.DEAD; import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; +import static org.opensearch.sql.spark.execution.session.SessionState.FAIL; import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; @@ -130,4 +132,9 @@ public Optional get(StatementId stID) { .statementModel(model) .build()); } + + @Override + public boolean isReady() { + return sessionModel.getSessionState() != DEAD && sessionModel.getSessionState() != FAIL; + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index 4d919d5e2e..d3d3411ded 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -37,4 +37,7 @@ public interface Session { SessionModel getSessionModel(); SessionId getSessionId(); + + /** return true if session is ready to use. */ + boolean isReady(); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f65049a7d9..6bc40c009b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -63,6 +63,7 @@ import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; @@ -390,6 +391,7 @@ public void withSessionCreateAsyncQueryFailed() { assertEquals("mock error", asyncQueryResults.getError()); } + // https://github.com/opensearch-project/sql/issues/2344 @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); @@ -419,6 +421,65 @@ public void createSessionMoreThanLimitFailed() { "The maximum number of active sessions can be supported is 1", exception.getMessage()); } + // https://github.com/opensearch-project/sql/issues/2360 + @Test + public void recreateSessionIfNotReady() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // set sessionState to FAIL + setSessionState(first.getSessionId(), SessionState.FAIL); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, first.getSessionId())); + + assertNotEquals(first.getSessionId(), second.getSessionId()); + + // set sessionState to FAIL + setSessionState(second.getSessionId(), SessionState.DEAD); + + // 3. reuse session id + CreateAsyncQueryResponse third = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, second.getSessionId())); + assertNotEquals(second.getSessionId(), third.getSessionId()); + } + + @Test + public void submitQueryInInvalidSessionThrowException() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + SessionId sessionId = SessionId.newSessionId(DATASOURCE); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId()))); + assertEquals("no session found. " + sessionId, exception.getMessage()); + } + private DataSourceServiceImpl createDataSourceService() { String masterKey = "a57d991d9b573f75b9bba1df"; DataSourceMetadataStorage dataSourceMetadataStorage = @@ -536,6 +597,6 @@ void setSessionState(String sessionId, SessionState sessionState) { Optional model = getSession(stateStore, DATASOURCE).apply(sessionId); SessionModel updated = updateSessionState(stateStore, DATASOURCE).apply(model.get(), sessionState); - assertEquals(SessionState.RUNNING, updated.getSessionState()); + assertEquals(sessionState, updated.getSessionState()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index fc8623d51a..743274d46c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -315,6 +315,7 @@ void testDispatchSelectQueryReuseSession() { doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); + when(session.isReady()).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);