Skip to content

Commit

Permalink
Added tests.
Browse files Browse the repository at this point in the history
Signed-off-by: dblock <dblock@amazon.com>
  • Loading branch information
dblock committed Sep 29, 2023
1 parent 46ab55f commit c16dc56
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private void registerRequestHandler(DynamicActionRegistry dynamicActionRegistry)
* Loads a single extension
* @param extension The extension to be loaded
*/
public void loadExtension(Extension extension) throws IOException {
public DiscoveryExtensionNode loadExtension(Extension extension) throws IOException {
validateExtension(extension);
DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode(
extension.getName(),
Expand All @@ -314,6 +314,12 @@ public void loadExtension(Extension extension) throws IOException {
extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode);
extensionSettingsMap.put(extension.getUniqueId(), extension);
logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension);
return discoveryExtensionNode;
}

public void initializeExtension(Extension extension) throws IOException {
DiscoveryExtensionNode node = loadExtension(extension);
initializeExtensionNode(node);
}

private void validateField(String fieldName, String value) throws IOException {
Expand All @@ -340,13 +346,11 @@ private void validateExtension(Extension extension) throws IOException {
*/
public void initialize() {
for (DiscoveryExtensionNode extension : extensionIdMap.values()) {
if (! initializedExtensions.containsKey(extension)) {
initializeExtension(extension);
}
initializeExtensionNode(extension);
}
}

private void initializeExtension(DiscoveryExtensionNode extension) {
private void initializeExtensionNode(DiscoveryExtensionNode extension) {

final CompletableFuture<InitializeExtensionResponse> inProgressFuture = new CompletableFuture<>();
final TransportResponseHandler<InitializeExtensionResponse> initializeExtensionResponseHandler = new TransportResponseHandler<
Expand Down Expand Up @@ -386,7 +390,7 @@ public String executor() {
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
logger.warn(String.format("Error registering extension: %s", extension.getId()), e);
logger.warn("Error registering extension: " + extension.getId(), e);
extensionIdMap.remove(extension.getId());
if (e.getCause() instanceof ConnectTransportException) {
logger.info("No response from extension to request.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public TransportResponse handleRegisterRestActionsRequest(
) throws Exception {
DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId());
if (discoveryExtensionNode == null) {
throw new IllegalStateException(String.format("Missing extension node for %s", restActionsRequest.getUniqueId()));
throw new IllegalStateException("Missing extension node for " + restActionsRequest.getUniqueId());
}
RestHandler handler = new RestSendToExtensionAction(
restActionsRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
extAdditionalSettings
);
try {
extensionsManager.loadExtension(extension);
extensionsManager.initialize();
extensionsManager.initializeExtension(extension);
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof TimeoutException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.discovery.InitializeExtensionRequest;
import org.opensearch.env.Environment;
import org.opensearch.env.EnvironmentSettingsResponse;
import org.opensearch.extensions.ExtensionsSettings.Extension;
Expand Down Expand Up @@ -67,7 +68,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand All @@ -78,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -410,36 +411,37 @@ public void testInitialize() throws Exception {
)
);

// Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for
// now.
// Test needs to be changed to mock the connection between the local node and an extension.
// Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045
// mockLogAppender.assertAllExpectationsMatched();
}
}

public void testInitializeExtensionTwice() throws Exception {
public void testInitializeExtension() throws Exception {
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

ThreadPool mockThreadPool = spy(threadPool);
ExecutorService mockExecutorService = mock(ExecutorService.class);
when(mockThreadPool.generic()).thenReturn(mockExecutorService);

TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
mockThreadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Collections.emptySet(),
NoopTracer.INSTANCE
TransportService mockTransportService = spy(
new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Collections.emptySet(),
NoopTracer.INSTANCE
)
);

doNothing().when(mockTransportService).connectToExtensionNode(any(DiscoveryExtensionNode.class));

doNothing().when(mockTransportService)
.sendRequest(any(DiscoveryExtensionNode.class), anyString(), any(InitializeExtensionRequest.class), any());

extensionsManager.initializeServicesAndRestHandler(
actionModule,
settingsModule,
transportService,
mockTransportService,
clusterService,
settings,
client,
Expand All @@ -458,8 +460,7 @@ public void testInitializeExtensionTwice() throws Exception {
null
);

extensionsManager.loadExtension(firstExtension);
extensionsManager.initialize();
extensionsManager.initializeExtension(firstExtension);

Extension secondExtension = new Extension(
"secondExtension",
Expand All @@ -473,12 +474,18 @@ public void testInitializeExtensionTwice() throws Exception {
null
);

extensionsManager.loadExtension(secondExtension);
extensionsManager.initialize();
extensionsManager.initializeExtension(secondExtension);

// When execution is mocked, the successful registration callback is not called and the extension is never added to
// registered extensions.
// verify(mockExecutorService, times(2)).execute(any());
ThreadPool.terminate(threadPool, 3, TimeUnit.SECONDS);

verify(mockTransportService, times(2)).connectToExtensionNode(any(DiscoveryExtensionNode.class));

verify(mockTransportService, times(2)).sendRequest(
any(DiscoveryExtensionNode.class),
anyString(),
any(InitializeExtensionRequest.class),
any()
);
}

public void testHandleRegisterRestActionsRequest() throws Exception {
Expand Down Expand Up @@ -515,20 +522,20 @@ public void testHandleRegisterRestActionsRequestRequiresDiscoveryNode() throws E
);
}

public void testHandleRegisterTwoRestActionsRequest() throws Exception {
public void testHandleRegisterRestActionsRequestMultiple() throws Exception {

ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
for (int i = 0; i < 2; i++) {
String uniqueIdStr = String.format("uniqueid-%d", i);
String uniqueIdStr = "uniqueid-%d" + i;

Set<Setting<?>> additionalSettings = extAwarePlugin.getExtensionSettings().stream().collect(Collectors.toSet());
ExtensionScopedSettings extensionScopedSettings = new ExtensionScopedSettings(additionalSettings);
Extension firstExtension = new Extension(
String.format("Extension %s", i),
"Extension %s" + i,
uniqueIdStr,
"127.0.0.0",
"9300",
Expand Down

0 comments on commit c16dc56

Please sign in to comment.