Skip to content

Commit

Permalink
Replace latches with CompletableFutures for extensions (opensearch-pr…
Browse files Browse the repository at this point in the history
…oject#5978)

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan authored Jan 23, 2023
1 parent f847fd5 commit b1de3b6
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 73 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Added support for feature flags in opensearch.yml ([#4959](https://github.com/opensearch-project/OpenSearch/pull/4959))
- Add query for initialized extensions ([#5658](https://github.com/opensearch-project/OpenSearch/pull/5658))
- Add update-index-settings allowlist for searchable snapshot ([#5907](https://github.com/opensearch-project/OpenSearch/pull/5907))
- Replace latches with CompletableFutures for extensions ([#5646](https://github.com/opensearch-project/OpenSearch/pull/5646))

### Dependencies
- Update nebula-publishing-plugin to 19.2.0 ([#5704](https://github.com/opensearch-project/OpenSearch/pull/5704))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -198,7 +200,7 @@ public void initializeServicesAndRestHandler(
*
* @param request which was sent by an extension.
*/
public ExtensionActionResponse handleTransportRequest(ExtensionActionRequest request) throws InterruptedException {
public ExtensionActionResponse handleTransportRequest(ExtensionActionRequest request) throws Exception {
return extensionTransportActionsHandler.sendTransportRequestToExtension(request);
}

Expand Down Expand Up @@ -401,13 +403,17 @@ public String executor() {
new InitializeExtensionRequest(transportService.getLocalNode(), extension),
initializeExtensionResponseHandler
);
// TODO: make asynchronous
inProgressFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS);
} catch (Exception e) {
try {
throw e;
} catch (Exception e1) {
logger.error(e.toString());
inProgressFuture.orTimeout(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join();
} catch (CompletionException e) {
if (e.getCause() instanceof TimeoutException) {
logger.info("No response from extension to request.");
}
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw new RuntimeException(e.getCause());
}
}
}
Expand Down Expand Up @@ -462,7 +468,7 @@ public void handleResponse(AcknowledgedResponse response) {

@Override
public void handleException(TransportException exp) {

inProgressIndexNameFuture.completeExceptionally(exp);
}

@Override
Expand Down Expand Up @@ -506,20 +512,21 @@ public void beforeIndexRemoved(
new IndicesModuleRequest(indexModule),
acknowledgedResponseHandler
);
// TODO: make asynchronous
inProgressIndexNameFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS);
logger.info("Received ack response from Extension");
} catch (Exception e) {
try {
throw e;
} catch (Exception e1) {
logger.error(e.toString());
}
inProgressIndexNameFuture.whenComplete((r, e) -> {
if (e != null) {
inProgressFuture.complete(response);
} else if (e == null) {
inProgressFuture.completeExceptionally(e);
}
});
} catch (Exception ex) {
inProgressFuture.completeExceptionally(ex);
}
}
});
} else {
inProgressFuture.complete(response);
}
inProgressFuture.complete(response);
}

@Override
Expand All @@ -542,14 +549,18 @@ public String executor() {
new IndicesModuleRequest(indexModule),
indicesModuleResponseHandler
);
// TODO: make asynchronous
inProgressFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS);
inProgressFuture.orTimeout(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join();
logger.info("Received response from Extension");
} catch (Exception e) {
try {
throw e;
} catch (Exception e1) {
logger.error(e.toString());
} catch (CompletionException e) {
if (e.getCause() instanceof TimeoutException) {
logger.info("No response from extension to request.");
}
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw new RuntimeException(e.getCause());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
* This class manages TransportActions for extensions
Expand Down Expand Up @@ -108,10 +110,9 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport
* @return {@link TransportResponse} which is sent back to the transport action invoker.
* @throws InterruptedException when message transport fails.
*/
public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request)
throws InterruptedException {
public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception {
DiscoveryExtensionNode extension = extensionIdMap.get(request.getUniqueId());
final CountDownLatch inProgressLatch = new CountDownLatch(1);
final CompletableFuture<ExtensionActionResponse> inProgressFuture = new CompletableFuture<>();
final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]);
client.execute(
ExtensionProxyAction.INSTANCE,
Expand All @@ -120,19 +121,32 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi
@Override
public void onResponse(ExtensionActionResponse actionResponse) {
response.setResponseBytes(actionResponse.getResponseBytes());
inProgressLatch.countDown();
inProgressFuture.complete(actionResponse);
}

@Override
public void onFailure(Exception exp) {
logger.debug("Transport request failed", exp);
byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8);
response.setResponseBytes(responseBytes);
inProgressLatch.countDown();
inProgressFuture.completeExceptionally(exp);
}
}
);
inProgressLatch.await(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS);
try {
inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join();
} catch (CompletionException e) {
if (e.getCause() instanceof TimeoutException) {
logger.info("No response from extension to request.");
}
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw new RuntimeException(e.getCause());
}
}
return response;
}

Expand All @@ -143,12 +157,12 @@ public void onFailure(Exception exp) {
* @return {@link ExtensionActionResponse} which encapsulates the transport response from the extension.
* @throws InterruptedException when message transport fails.
*/
public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws InterruptedException {
public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws Exception {
DiscoveryExtensionNode extension = actionsMap.get(request.getAction());
if (extension == null) {
throw new ActionNotFoundTransportException(request.getAction());
}
final CountDownLatch inProgressLatch = new CountDownLatch(1);
final CompletableFuture<ExtensionActionResponse> inProgressFuture = new CompletableFuture<>();
final ExtensionActionResponse extensionActionResponse = new ExtensionActionResponse(new byte[0]);
final TransportResponseHandler<ExtensionActionResponse> extensionActionResponseTransportResponseHandler =
new TransportResponseHandler<ExtensionActionResponse>() {
Expand All @@ -161,15 +175,15 @@ public ExtensionActionResponse read(StreamInput in) throws IOException {
@Override
public void handleResponse(ExtensionActionResponse response) {
extensionActionResponse.setResponseBytes(response.getResponseBytes());
inProgressLatch.countDown();
inProgressFuture.complete(response);
}

@Override
public void handleException(TransportException exp) {
logger.debug("Transport request failed", exp);
byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8);
extensionActionResponse.setResponseBytes(responseBytes);
inProgressLatch.countDown();
inProgressFuture.completeExceptionally(exp);
}

@Override
Expand All @@ -187,7 +201,20 @@ public String executor() {
} catch (Exception e) {
logger.info("Failed to send transport action to extension " + extension.getName(), e);
}
inProgressLatch.await(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS);
try {
inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join();
} catch (CompletionException e) {
if (e.getCause() instanceof TimeoutException) {
logger.info("No response from extension to request.");
}
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw new RuntimeException(e.getCause());
}
}
return extensionActionResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableList;
Expand Down Expand Up @@ -122,7 +125,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
emptyList(),
false
);
final CountDownLatch inProgressLatch = new CountDownLatch(1);
final CompletableFuture<RestExecuteOnExtensionResponse> inProgressFuture = new CompletableFuture<>();
final TransportResponseHandler<RestExecuteOnExtensionResponse> restExecuteOnExtensionResponseHandler = new TransportResponseHandler<
RestExecuteOnExtensionResponse>() {

Expand All @@ -143,15 +146,13 @@ public void handleResponse(RestExecuteOnExtensionResponse response) {
if (response.isContentConsumed()) {
request.content();
}
inProgressFuture.complete(response);
}

@Override
public void handleException(TransportException exp) {
logger.debug("REST request failed", exp);
// Status is already defaulted to 500 (INTERNAL_SERVER_ERROR)
byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8);
restExecuteOnExtensionResponse.setContent(responseBytes);
inProgressLatch.countDown();
inProgressFuture.completeExceptionally(exp);
}

@Override
Expand All @@ -172,15 +173,24 @@ public String executor() {
new ExtensionRestRequest(method, path, params, contentType, content, requestIssuerIdentity),
restExecuteOnExtensionResponseHandler
);
try {
inProgressLatch.await(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join();
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof TimeoutException) {
return channel -> channel.sendResponse(
new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, "No response from extension to request.")
);
}
} catch (Exception e) {
logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), e);
if (e.getCause() instanceof RuntimeException) {
throw (RuntimeException) e.getCause();
} else if (e.getCause() instanceof Error) {
throw (Error) e.getCause();
} else {
throw new RuntimeException(e.getCause());
}
} catch (Exception ex) {
logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), ex);
return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage()));
}
BytesRestResponse restResponse = new BytesRestResponse(
restExecuteOnExtensionResponse.getStatus(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
import org.opensearch.test.transport.MockTransportService;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.ConnectTransportException;
import org.opensearch.transport.NodeNotConnectedException;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportResponse;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -427,23 +429,23 @@ public void testInitialize() throws Exception {

mockLogAppender.addExpectation(
new MockLogAppender.SeenEventExpectation(
"Connect Transport Exception 1",
"Node Not Connected Exception 1",
"org.opensearch.extensions.ExtensionsManager",
Level.ERROR,
"ConnectTransportException[[firstExtension][127.0.0.0:9300] connect_timeout[30s]]"
"[secondExtension][127.0.0.1:9301] Node not connected"
)
);

mockLogAppender.addExpectation(
new MockLogAppender.SeenEventExpectation(
"Connect Transport Exception 2",
"Node Not Connected Exception 2",
"org.opensearch.extensions.ExtensionsManager",
Level.ERROR,
"ConnectTransportException[[secondExtension][127.0.0.1:9301] connect_exception]; nested: ConnectException[Connection refused];"
"[firstExtension][127.0.0.0:9300] Node not connected"
)
);

extensionsManager.initialize();
expectThrows(ConnectTransportException.class, () -> extensionsManager.initialize());

// Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for
// now.
Expand Down Expand Up @@ -831,21 +833,8 @@ public void testOnIndexModule() throws Exception {
new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)),
Collections.emptyMap()
);
expectThrows(NodeNotConnectedException.class, () -> extensionsManager.onIndexModule(indexModule));

try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) {

mockLogAppender.addExpectation(
new MockLogAppender.SeenEventExpectation(
"IndicesModuleRequest Failure",
"org.opensearch.extensions.ExtensionsManager",
Level.ERROR,
"IndicesModuleRequest failed"
)
);

extensionsManager.onIndexModule(indexModule);
mockLogAppender.assertAllExpectationsMatched();
}
}

private void initialize(ExtensionsManager extensionsManager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.ActionNotFoundTransportException;
import org.opensearch.transport.NodeNotConnectedException;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.nio.MockNioTransport;

Expand Down Expand Up @@ -172,10 +173,6 @@ public void testSendTransportRequestToExtension() throws InterruptedException {
);
assertTrue(response.getStatus());

ExtensionActionResponse extensionResponse = extensionTransportActionsHandler.sendTransportRequestToExtension(request);
assertEquals(
"Request failed: [firstExtension][127.0.0.0:9300] Node not connected",
new String(extensionResponse.getResponseBytes(), StandardCharsets.UTF_8)
);
expectThrows(NodeNotConnectedException.class, () -> extensionTransportActionsHandler.sendTransportRequestToExtension(request));
}
}

0 comments on commit b1de3b6

Please sign in to comment.