Skip to content

Commit

Permalink
Added tests.
Browse files Browse the repository at this point in the history
Signed-off-by: dblock <dblock@amazon.com>
  • Loading branch information
dblock committed Oct 2, 2023
1 parent d1047d7 commit da21f03
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private void registerRequestHandler(DynamicActionRegistry dynamicActionRegistry)
* Loads a single extension
* @param extension The extension to be loaded
*/
public void loadExtension(Extension extension) throws IOException {
public DiscoveryExtensionNode loadExtension(Extension extension) throws IOException {
validateExtension(extension);
DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode(
extension.getName(),
Expand All @@ -314,6 +314,12 @@ public void loadExtension(Extension extension) throws IOException {
extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode);
extensionSettingsMap.put(extension.getUniqueId(), extension);
logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension);
return discoveryExtensionNode;
}

public void initializeExtension(Extension extension) throws IOException {
DiscoveryExtensionNode node = loadExtension(extension);
initializeExtensionNode(node);
}

private void validateField(String fieldName, String value) throws IOException {
Expand All @@ -340,13 +346,11 @@ private void validateExtension(Extension extension) throws IOException {
*/
public void initialize() {
for (DiscoveryExtensionNode extension : extensionIdMap.values()) {
if (! initializedExtensions.containsKey(extension)) {
initializeExtension(extension);
}
initializeExtensionNode(extension);
}
}

private void initializeExtension(DiscoveryExtensionNode extension) {
public void initializeExtensionNode(DiscoveryExtensionNode extensionNode) {

final CompletableFuture<InitializeExtensionResponse> inProgressFuture = new CompletableFuture<>();
final TransportResponseHandler<InitializeExtensionResponse> initializeExtensionResponseHandler = new TransportResponseHandler<
Expand Down Expand Up @@ -386,8 +390,8 @@ public String executor() {
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
logger.warn(String.format("Error registering extension: %s", extension.getId()), e);
extensionIdMap.remove(extension.getId());
logger.warn("Error registering extension: " + extensionNode.getId(), e);
extensionIdMap.remove(extensionNode.getId());
if (e.getCause() instanceof ConnectTransportException) {
logger.info("No response from extension to request.", e);
throw (ConnectTransportException) e.getCause();
Expand All @@ -402,11 +406,11 @@ public void onFailure(Exception e) {

@Override
protected void doRun() throws Exception {
transportService.connectToExtensionNode(extension);
transportService.connectToExtensionNode(extensionNode);
transportService.sendRequest(
extension,
extensionNode,
REQUEST_EXTENSION_ACTION_NAME,
new InitializeExtensionRequest(transportService.getLocalNode(), extension, issueServiceAccount(extension)),
new InitializeExtensionRequest(transportService.getLocalNode(), extensionNode, issueServiceAccount(extensionNode)),
initializeExtensionResponseHandler
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public TransportResponse handleRegisterRestActionsRequest(
) throws Exception {
DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId());
if (discoveryExtensionNode == null) {
throw new IllegalStateException(String.format("Missing extension node for %s", restActionsRequest.getUniqueId()));
throw new IllegalStateException("Missing extension node for " + restActionsRequest.getUniqueId());
}
RestHandler handler = new RestSendToExtensionAction(
restActionsRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
extAdditionalSettings
);
try {
extensionsManager.loadExtension(extension);
extensionsManager.initialize();
extensionsManager.initializeExtension(extension);
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof TimeoutException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.discovery.InitializeExtensionRequest;
import org.opensearch.env.Environment;
import org.opensearch.env.EnvironmentSettingsResponse;
import org.opensearch.extensions.ExtensionsSettings.Extension;
Expand Down Expand Up @@ -67,7 +68,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand All @@ -78,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -410,36 +411,37 @@ public void testInitialize() throws Exception {
)
);

// Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for
// now.
// Test needs to be changed to mock the connection between the local node and an extension.
// Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045
// mockLogAppender.assertAllExpectationsMatched();
}
}

public void testInitializeExtensionTwice() throws Exception {
public void testInitializeExtension() throws Exception {
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

ThreadPool mockThreadPool = spy(threadPool);
ExecutorService mockExecutorService = mock(ExecutorService.class);
when(mockThreadPool.generic()).thenReturn(mockExecutorService);

TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
mockThreadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Collections.emptySet(),
NoopTracer.INSTANCE
TransportService mockTransportService = spy(
new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Collections.emptySet(),
NoopTracer.INSTANCE
)
);

doNothing().when(mockTransportService).connectToExtensionNode(any(DiscoveryExtensionNode.class));

doNothing().when(mockTransportService)
.sendRequest(any(DiscoveryExtensionNode.class), anyString(), any(InitializeExtensionRequest.class), any());

extensionsManager.initializeServicesAndRestHandler(
actionModule,
settingsModule,
transportService,
mockTransportService,
clusterService,
settings,
client,
Expand All @@ -458,8 +460,7 @@ public void testInitializeExtensionTwice() throws Exception {
null
);

extensionsManager.loadExtension(firstExtension);
extensionsManager.initialize();
extensionsManager.initializeExtension(firstExtension);

Extension secondExtension = new Extension(
"secondExtension",
Expand All @@ -473,12 +474,18 @@ public void testInitializeExtensionTwice() throws Exception {
null
);

extensionsManager.loadExtension(secondExtension);
extensionsManager.initialize();
extensionsManager.initializeExtension(secondExtension);

// When execution is mocked, the successful registration callback is not called and the extension is never added to
// registered extensions.
// verify(mockExecutorService, times(2)).execute(any());
ThreadPool.terminate(threadPool, 3, TimeUnit.SECONDS);

verify(mockTransportService, times(2)).connectToExtensionNode(any(DiscoveryExtensionNode.class));

verify(mockTransportService, times(2)).sendRequest(
any(DiscoveryExtensionNode.class),
anyString(),
any(InitializeExtensionRequest.class),
any()
);
}

public void testHandleRegisterRestActionsRequest() throws Exception {
Expand Down Expand Up @@ -515,20 +522,20 @@ public void testHandleRegisterRestActionsRequestRequiresDiscoveryNode() throws E
);
}

public void testHandleRegisterTwoRestActionsRequest() throws Exception {
public void testHandleRegisterRestActionsRequestMultiple() throws Exception {

ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
for (int i = 0; i < 2; i++) {
String uniqueIdStr = String.format("uniqueid-%d", i);
String uniqueIdStr = "uniqueid-%d" + i;

Set<Setting<?>> additionalSettings = extAwarePlugin.getExtensionSettings().stream().collect(Collectors.toSet());
ExtensionScopedSettings extensionScopedSettings = new ExtensionScopedSettings(additionalSettings);
Extension firstExtension = new Extension(
String.format("Extension %s", i),
"Extension %s" + i,
uniqueIdStr,
"127.0.0.0",
"9300",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.extensions.DiscoveryExtensionNode;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.ExtensionsSettings;
import org.opensearch.extensions.ExtensionsSettings.Extension;
import org.opensearch.identity.IdentityService;
import org.opensearch.rest.RestRequest;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
Expand Down Expand Up @@ -160,8 +161,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th

// optionally, you can stub out some methods:
when(spy.getAdditionalSettings()).thenCallRealMethod();
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
Mockito.doNothing().when(spy).initialize();
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
Expand All @@ -177,10 +178,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
restInitializeExtensionAction.handleRequest(request, channel, null);

assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
assertTrue(extension.isPresent());
assertEquals(true, extension.get().getAdditionalSettings().get(boolSetting));
assertEquals("customSetting", extension.get().getAdditionalSettings().get(stringSetting));
Expand Down Expand Up @@ -210,8 +211,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing

// optionally, you can stub out some methods:
when(spy.getAdditionalSettings()).thenCallRealMethod();
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
Mockito.doNothing().when(spy).initialize();
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
Expand All @@ -227,10 +228,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
restInitializeExtensionAction.handleRequest(request, channel, null);

assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
assertTrue(extension.isPresent());
assertEquals(false, extension.get().getAdditionalSettings().get(boolSetting));
assertEquals("default", extension.get().getAdditionalSettings().get(stringSetting));
Expand Down

0 comments on commit da21f03

Please sign in to comment.