From 60ea64144ecc9318f87049d5dd42a90c6b144f56 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 12 Dec 2022 07:47:05 +0000 Subject: [PATCH] Merge second batch of feature/extensions into main Signed-off-by: Ryan Bogan --- .../NamedWriteableRegistryParseRequest.java | 98 ++++ .../NamedWriteableRegistryResponse.java | 91 ++++ ...t.java => InitializeExtensionRequest.java} | 36 +- ....java => InitializeExtensionResponse.java} | 8 +- .../extensions/ExtensionBooleanResponse.java | 68 +++ .../ExtensionNamedWriteableRegistry.java | 148 ++++++ .../extensions/ExtensionReader.java | 45 ++ .../extensions/ExtensionsManager.java | 255 +++++++--- ...WriteableRegistryParseResponseHandler.java | 47 ++ ...NamedWriteableRegistryResponseHandler.java | 142 ++++++ .../extensions/OpenSearchRequest.java | 73 +++ .../RegisterTransportActionsRequest.java | 79 ++++ .../rest/RegisterRestActionsRequest.java | 72 +++ .../rest/RegisterRestActionsResponse.java | 41 ++ .../rest/RestActionsRequestHandler.java | 62 +++ .../rest/RestExecuteOnExtensionRequest.java | 77 +++ .../rest/RestExecuteOnExtensionResponse.java | 112 +++++ .../rest/RestSendToExtensionAction.java | 185 ++++++++ .../extensions/rest/package-info.java | 10 + .../index/AcknowledgedResponse.java | 42 -- .../main/java/org/opensearch/node/Node.java | 4 +- .../extensions/ExtensionsManagerTests.java | 447 ++++++++++++++---- .../RegisterTransportActionsRequestTests.java | 42 ++ .../rest/RegisterRestActionsTests.java | 62 +++ .../rest/RestExecuteOnExtensionTests.java | 94 ++++ .../rest/RestSendToExtensionActionTests.java | 159 +++++++ 26 files changed, 2264 insertions(+), 235 deletions(-) create mode 100644 server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryParseRequest.java create mode 100644 server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryResponse.java rename server/src/main/java/org/opensearch/discovery/{PluginRequest.java => InitializeExtensionRequest.java} (60%) rename server/src/main/java/org/opensearch/discovery/{PluginResponse.java => InitializeExtensionResponse.java} (88%) create mode 100644 server/src/main/java/org/opensearch/extensions/ExtensionBooleanResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/ExtensionNamedWriteableRegistry.java create mode 100644 server/src/main/java/org/opensearch/extensions/ExtensionReader.java create mode 100644 server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryParseResponseHandler.java create mode 100644 server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryResponseHandler.java create mode 100644 server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java create mode 100644 server/src/main/java/org/opensearch/extensions/rest/package-info.java delete mode 100644 server/src/main/java/org/opensearch/index/AcknowledgedResponse.java create mode 100644 server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java create mode 100644 server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java diff --git a/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryParseRequest.java b/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryParseRequest.java new file mode 100644 index 0000000000000..542de80afbd0d --- /dev/null +++ b/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryParseRequest.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.io.stream; + +import org.opensearch.transport.TransportRequest; +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Extensibility support for Named Writeable Registry: request to extensions to parse context + * + * @opensearch.internal + * */ +public class NamedWriteableRegistryParseRequest extends TransportRequest { + + private final Class categoryClass; + private byte[] context; + + /** + * @param categoryClass Class category for this parse request + * @param context StreamInput object to convert into a byte array and transport to the extension + * @throws IllegalArgumentException if context bytes could not be read + */ + public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput context) { + try { + byte[] streamInputBytes = context.readAllBytes(); + this.categoryClass = categoryClass; + this.context = Arrays.copyOf(streamInputBytes, streamInputBytes.length); + } catch (IOException e) { + throw new IllegalArgumentException("Invalid context", e); + } + } + + /** + * @param in StreamInput from which class fields are read from + * @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime + */ + public NamedWriteableRegistryParseRequest(StreamInput in) throws IOException { + super(in); + try { + this.categoryClass = Class.forName(in.readString()); + this.context = in.readByteArray(); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Category class definition not found", e); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(categoryClass.getName()); + out.writeByteArray(context); + } + + @Override + public String toString() { + return "NamedWriteableRegistryParseRequest{" + + "categoryClass=" + + categoryClass.getName() + + ", context=" + + context.toString() + + " }"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NamedWriteableRegistryParseRequest that = (NamedWriteableRegistryParseRequest) o; + return Objects.equals(categoryClass, that.categoryClass) && Objects.equals(context, that.context); + } + + @Override + public int hashCode() { + return Objects.hash(categoryClass, context); + } + + /** + * Returns the class instance of the category class sent over by the SDK + */ + public Class getCategoryClass() { + return this.categoryClass; + } + + /** + * Returns a copy of a byte array that a {@link Writeable.Reader} will be applied to. This byte array is generated from a {@link StreamInput} instance and transported to the SDK for deserialization. + */ + public byte[] getContext() { + return Arrays.copyOf(this.context, this.context.length); + } +} diff --git a/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryResponse.java b/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryResponse.java new file mode 100644 index 0000000000000..09c46d328c192 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/io/stream/NamedWriteableRegistryResponse.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.io.stream; + +import org.opensearch.transport.TransportResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Extensibility support for Named Writeable Registry: response from extensions for name writeable registry entries + * + * @opensearch.internal + */ +public class NamedWriteableRegistryResponse extends TransportResponse { + + private final Map registry; + + /** + * @param registry Map of writeable names and their associated category class + */ + public NamedWriteableRegistryResponse(Map registry) { + this.registry = new HashMap<>(registry); + } + + /** + * @param in StreamInput from which map entries of writeable names and their associated category classes are read from + * @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime + */ + public NamedWriteableRegistryResponse(StreamInput in) throws IOException { + super(in); + // Stream output for registry map begins with a variable integer that tells us the number of entries being sent across the wire + Map registry = new HashMap<>(); + int registryEntryCount = in.readVInt(); + for (int i = 0; i < registryEntryCount; i++) { + try { + String name = in.readString(); + Class categoryClass = Class.forName(in.readString()); + registry.put(name, categoryClass); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Category class definition not found", e); + } + } + + this.registry = registry; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + // Stream out registry size prior to streaming out registry entries + out.writeVInt(this.registry.size()); + for (Map.Entry entry : registry.entrySet()) { + out.writeString(entry.getKey()); // Unique named writeable name + out.writeString(entry.getValue().getName()); // Fully qualified category class name + } + } + + @Override + public String toString() { + return "NamedWritableRegistryResponse{" + "registry=" + registry.toString() + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NamedWriteableRegistryResponse that = (NamedWriteableRegistryResponse) o; + return Objects.equals(registry, that.registry); + } + + @Override + public int hashCode() { + return Objects.hash(registry); + } + + /** + * Returns a map of writeable names and their associated category class + */ + public Map getRegistry() { + return Collections.unmodifiableMap(this.registry); + } + +} diff --git a/server/src/main/java/org/opensearch/discovery/PluginRequest.java b/server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java similarity index 60% rename from server/src/main/java/org/opensearch/discovery/PluginRequest.java rename to server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java index 7992de4342d86..b83e9080fa452 100644 --- a/server/src/main/java/org/opensearch/discovery/PluginRequest.java +++ b/server/src/main/java/org/opensearch/discovery/InitializeExtensionRequest.java @@ -15,62 +15,58 @@ import org.opensearch.transport.TransportRequest; import java.io.IOException; -import java.util.List; import java.util.Objects; /** - * PluginRequest to intialize plugin + * InitializeExtensionRequest to intialize plugin * * @opensearch.internal */ -public class PluginRequest extends TransportRequest { +public class InitializeExtensionRequest extends TransportRequest { private final DiscoveryNode sourceNode; - /* - * TODO change DiscoveryNode to Extension information - */ - private final List extensions; + private final DiscoveryExtensionNode extension; - public PluginRequest(DiscoveryNode sourceNode, List extensions) { + public InitializeExtensionRequest(DiscoveryNode sourceNode, DiscoveryExtensionNode extension) { this.sourceNode = sourceNode; - this.extensions = extensions; + this.extension = extension; } - public PluginRequest(StreamInput in) throws IOException { + public InitializeExtensionRequest(StreamInput in) throws IOException { super(in); sourceNode = new DiscoveryNode(in); - extensions = in.readList(DiscoveryExtensionNode::new); + extension = new DiscoveryExtensionNode(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); sourceNode.writeTo(out); - out.writeList(extensions); - } - - public List getExtensions() { - return extensions; + extension.writeTo(out); } public DiscoveryNode getSourceNode() { return sourceNode; } + public DiscoveryExtensionNode getExtension() { + return extension; + } + @Override public String toString() { - return "PluginRequest{" + "sourceNode=" + sourceNode + ", extensions=" + extensions + '}'; + return "InitializeExtensionsRequest{" + "sourceNode=" + sourceNode + ", extension=" + extension + '}'; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PluginRequest that = (PluginRequest) o; - return Objects.equals(sourceNode, that.sourceNode) && Objects.equals(extensions, that.extensions); + InitializeExtensionRequest that = (InitializeExtensionRequest) o; + return Objects.equals(sourceNode, that.sourceNode) && Objects.equals(extension, that.extension); } @Override public int hashCode() { - return Objects.hash(sourceNode, extensions); + return Objects.hash(sourceNode, extension); } } diff --git a/server/src/main/java/org/opensearch/discovery/PluginResponse.java b/server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java similarity index 88% rename from server/src/main/java/org/opensearch/discovery/PluginResponse.java rename to server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java index f8f20214e5846..3be97816dc541 100644 --- a/server/src/main/java/org/opensearch/discovery/PluginResponse.java +++ b/server/src/main/java/org/opensearch/discovery/InitializeExtensionResponse.java @@ -44,14 +44,14 @@ * * @opensearch.internal */ -public class PluginResponse extends TransportResponse { +public class InitializeExtensionResponse extends TransportResponse { private String name; - public PluginResponse(String name) { + public InitializeExtensionResponse(String name) { this.name = name; } - public PluginResponse(StreamInput in) throws IOException { + public InitializeExtensionResponse(StreamInput in) throws IOException { name = in.readString(); } @@ -77,7 +77,7 @@ public String toString() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PluginResponse that = (PluginResponse) o; + InitializeExtensionResponse that = (InitializeExtensionResponse) o; return Objects.equals(name, that.name); } diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionBooleanResponse.java b/server/src/main/java/org/opensearch/extensions/ExtensionBooleanResponse.java new file mode 100644 index 0000000000000..fc5855ea50a68 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/ExtensionBooleanResponse.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportResponse; +import java.io.IOException; +import java.util.Objects; + +/** + * Generic boolean response indicating the status of some previous request sent to the SDK + * + * @opensearch.internal + */ +public class ExtensionBooleanResponse extends TransportResponse { + + private final boolean status; + + /** + * @param status Boolean indicating the status of the parse request sent to the SDK + */ + public ExtensionBooleanResponse(boolean status) { + this.status = status; + } + + public ExtensionBooleanResponse(StreamInput in) throws IOException { + super(in); + this.status = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(status); + } + + @Override + public String toString() { + return "ExtensionBooleanResponse{" + "status=" + this.status + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ExtensionBooleanResponse that = (ExtensionBooleanResponse) o; + return Objects.equals(this.status, that.status); + } + + @Override + public int hashCode() { + return Objects.hash(status); + } + + /** + * Returns a boolean indicating the success of the request sent to the SDK + */ + public boolean getStatus() { + return this.status; + } + +} diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionNamedWriteableRegistry.java b/server/src/main/java/org/opensearch/extensions/ExtensionNamedWriteableRegistry.java new file mode 100644 index 0000000000000..28995e2634690 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/ExtensionNamedWriteableRegistry.java @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.extensions.ExtensionsManager.OpenSearchRequestType; +import org.opensearch.transport.TransportService; + +/** + * API for Named Writeables for extensions + * + * @opensearch.internal + */ +public class ExtensionNamedWriteableRegistry { + + private static final Logger logger = LogManager.getLogger(ExtensionNamedWriteableRegistry.class); + + private Map>> extensionNamedWriteableRegistry; + private List extensions; + private TransportService transportService; + + /** + * Initializes a new ExtensionNamedWriteableRegistry + * + * @param extensions List of DiscoveryExtensions to send requests to + * @param transportService Service that facilitates transport requests + */ + public ExtensionNamedWriteableRegistry(List extensions, TransportService transportService) { + this.extensions = extensions; + this.extensionNamedWriteableRegistry = new HashMap<>(); + this.transportService = transportService; + + getNamedWriteables(); + } + + /** + * Iterates through all discovered extensions, sends transport requests for named writeables and consolidates all entires into a central named writeable registry for extensions. + */ + public void getNamedWriteables() { + // Retrieve named writeable registry entries from each extension + for (DiscoveryNode extensionNode : extensions) { + try { + Map>> extensionRegistry = getNamedWriteables(extensionNode); + if (extensionRegistry.isEmpty() == false) { + this.extensionNamedWriteableRegistry.putAll(extensionRegistry); + } + } catch (UnknownHostException e) { + logger.error(e.toString()); + } + } + + // TODO : Invoke during the consolidation of named writeables within Node.java and return extension entries there + // (https://github.com/opensearch-project/OpenSearch/issues/4067) + } + + /** + * Sends a transport request for named writeables to an extension, identified by the given DiscoveryNode, and processes the response into registry entries + * + * @param extensionNode DiscoveryNode identifying the extension + * @throws UnknownHostException if connection to the extension node failed + * @return A map of category classes and their associated names and readers for this discovery node + */ + private Map>> getNamedWriteables(DiscoveryNode extensionNode) + throws UnknownHostException { + NamedWriteableRegistryResponseHandler namedWriteableRegistryResponseHandler = new NamedWriteableRegistryResponseHandler( + extensionNode, + transportService, + ExtensionsManager.REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE + ); + try { + logger.info("Sending extension request type: " + ExtensionsManager.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY); + transportService.sendRequest( + extensionNode, + ExtensionsManager.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY, + new OpenSearchRequest(OpenSearchRequestType.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY), + namedWriteableRegistryResponseHandler + ); + } catch (Exception e) { + logger.error(e.toString()); + } + + return namedWriteableRegistryResponseHandler.getExtensionRegistry(); + } + + /** + * Iterates through list of discovered extensions and returns the callback method associated with the given category class and name + * + * @param categoryClass Class that the Writeable object extends + * @param name Unique name identifiying the Writeable object + * @throws IllegalArgumentException if there is no reader associated with the given category class and name + * @return A map of the discovery node and its associated extension reader + */ + public Map getExtensionReader(Class categoryClass, String name) { + + ExtensionReader reader = null; + DiscoveryNode extension = null; + + // The specific extension that the reader is associated with is not known, must iterate through all of them + for (DiscoveryNode extensionNode : extensions) { + reader = getExtensionReader(extensionNode, categoryClass, name); + if (reader != null) { + extension = extensionNode; + break; + } + } + + // At this point, if reader does not exist throughout all extensionNodes, named writeable is not registered, throw exception + if (reader == null) { + throw new IllegalArgumentException("Unknown NamedWriteable [" + categoryClass.getName() + "][" + name + "]"); + } + return Collections.singletonMap(extension, reader); + } + + /** + * Returns the callback method associated with the given extension node, category class and name + * + * @param extensionNode Discovery Node identifying the extension associated with the category class and name + * @param categoryClass Class that the Writeable object extends + * @param name Unique name identifying the Writeable object + * @return The extension reader + */ + private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class categoryClass, String name) { + ExtensionReader reader = null; + Map> categoryMap = this.extensionNamedWriteableRegistry.get(extensionNode); + if (categoryMap != null) { + Map readerMap = categoryMap.get(categoryClass); + if (readerMap != null) { + reader = readerMap.get(name); + } + } + return reader; + } + +} diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionReader.java b/server/src/main/java/org/opensearch/extensions/ExtensionReader.java new file mode 100644 index 0000000000000..e54e3a6a4f940 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/ExtensionReader.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import java.net.UnknownHostException; +import org.opensearch.cluster.node.DiscoveryNode; + +/** + * Reference to a method that transports a parse request to an extension. By convention, this method takes + * a category class used to identify the reader defined within the JVM that the extension is running on. + * Additionally, this method takes in the extension's corresponding DiscoveryNode and a byte array (context) that the + * extension's reader will be applied to. + * + * By convention the extensions' reader is a constructor that takes StreamInput as an argument for most classes and a static method for things like enums. + * Classes will implement this via a constructor (or a static method in the case of enumerations), it's something that should + * look like: + *

+ * public MyClass(final StreamInput in) throws IOException {
+ * *     this.someValue = in.readVInt();
+ *     this.someMap = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
+ * }
+ * 
+ * + * @opensearch.internal + */ +@FunctionalInterface +public interface ExtensionReader { + + /** + * Transports category class, and StreamInput (context), to the extension identified by the Discovery Node + * + * @param extensionNode Discovery Node identifying the Extension + * @param categoryClass Super class that the reader extends + * @param context Some context to transport + * @throws UnknownHostException if the extension node host IP address could not be determined + */ + void parse(DiscoveryNode extensionNode, Class categoryClass, Object context) throws UnknownHostException; + +} diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index b809f2e35a483..7314d64f4efef 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -34,17 +35,19 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; -import org.opensearch.discovery.PluginRequest; -import org.opensearch.discovery.PluginResponse; +import org.opensearch.discovery.InitializeExtensionRequest; +import org.opensearch.discovery.InitializeExtensionResponse; import org.opensearch.extensions.ExtensionsSettings.Extension; +import org.opensearch.extensions.rest.RegisterRestActionsRequest; +import org.opensearch.extensions.rest.RestActionsRequestHandler; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexService; -import org.opensearch.index.AcknowledgedResponse; import org.opensearch.index.IndicesModuleRequest; import org.opensearch.index.IndicesModuleResponse; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.indices.cluster.IndicesClusterStateService; import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestController; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportResponse; @@ -55,7 +58,7 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; /** - * The main class for Plugin Extensibility + * The main class for managing Extension communication with the OpenSearch Node. * * @opensearch.internal */ @@ -66,6 +69,11 @@ public class ExtensionsManager { public static final String REQUEST_EXTENSION_CLUSTER_STATE = "internal:discovery/clusterstate"; public static final String REQUEST_EXTENSION_LOCAL_NODE = "internal:discovery/localnode"; public static final String REQUEST_EXTENSION_CLUSTER_SETTINGS = "internal:discovery/clustersettings"; + public static final String REQUEST_EXTENSION_REGISTER_REST_ACTIONS = "internal:discovery/registerrestactions"; + public static final String REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY = "internal:discovery/namedwriteableregistry"; + public static final String REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE = "internal:discovery/parsenamedwriteable"; + public static final String REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION = "internal:extensions/restexecuteonextensiontaction"; + public static final String REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS = "internal:discovery/registertransportactions"; private static final Logger logger = LogManager.getLogger(ExtensionsManager.class); @@ -78,29 +86,49 @@ public static enum RequestType { REQUEST_EXTENSION_CLUSTER_STATE, REQUEST_EXTENSION_LOCAL_NODE, REQUEST_EXTENSION_CLUSTER_SETTINGS, + REQUEST_EXTENSION_REGISTER_REST_ACTIONS, CREATE_COMPONENT, ON_INDEX_MODULE, GET_SETTINGS }; + /** + * Enum for OpenSearch Requests + * + * @opensearch.internal + */ + public static enum OpenSearchRequestType { + REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY + } + private final Path extensionsPath; - private final List uninitializedExtensions; + // A list of initialized extensions, a subset of the values of map below which includes all extensions private List extensions; + private Map extensionIdMap; + private RestActionsRequestHandler restActionsRequestHandler; private TransportService transportService; private ClusterService clusterService; + ExtensionNamedWriteableRegistry namedWriteableRegistry; public ExtensionsManager() { this.extensionsPath = Path.of(""); - this.uninitializedExtensions = new ArrayList(); } + /** + * Instantiate a new ExtensionsManager object to handle requests and responses from extensions. This is called during Node bootstrap. + * + * @param settings Settings from the node the orchestrator is running on. + * @param extensionsPath Path to a directory containing extensions. + * @throws IOException If the extensions discovery file is not properly retrieved. + */ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOException { logger.info("ExtensionsManager initialized"); this.extensionsPath = extensionsPath; this.transportService = null; - this.uninitializedExtensions = new ArrayList(); this.extensions = new ArrayList(); + this.extensionIdMap = new HashMap(); this.clusterService = null; + this.namedWriteableRegistry = null; /* * Now Discover extensions @@ -109,16 +137,38 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept } - public void setTransportService(TransportService transportService) { + /** + * Initializes the {@link RestActionsRequestHandler}, {@link TransportService} and {@link ClusterService}. This is called during Node bootstrap. + * Lists/maps of extensions have already been initialized but not yet populated. + * + * @param restController The RestController on which to register Rest Actions. + * @param transportService The Node's transport service. + * @param clusterService The Node's cluster service. + */ + public void initializeServicesAndRestHandler( + RestController restController, + TransportService transportService, + ClusterService clusterService + ) { + this.restActionsRequestHandler = new RestActionsRequestHandler(restController, extensionIdMap, transportService); this.transportService = transportService; + this.clusterService = clusterService; registerRequestHandler(); } - public void setClusterService(ClusterService clusterService) { - this.clusterService = clusterService; + public void setNamedWriteableRegistry() { + this.namedWriteableRegistry = new ExtensionNamedWriteableRegistry(extensions, transportService); } private void registerRequestHandler() { + transportService.registerRequestHandler( + REQUEST_EXTENSION_REGISTER_REST_ACTIONS, + ThreadPool.Names.GENERIC, + false, + false, + RegisterRestActionsRequest::new, + ((request, channel, task) -> channel.sendResponse(restActionsRequestHandler.handleRegisterRestActionsRequest(request))) + ); transportService.registerRequestHandler( REQUEST_EXTENSION_CLUSTER_STATE, ThreadPool.Names.GENERIC, @@ -143,6 +193,14 @@ private void registerRequestHandler() { ExtensionRequest::new, ((request, channel, task) -> channel.sendResponse(handleExtensionRequest(request))) ); + transportService.registerRequestHandler( + REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS, + ThreadPool.Names.GENERIC, + false, + false, + RegisterTransportActionsRequest::new, + ((request, channel, task) -> channel.sendResponse(handleRegisterTransportActionsRequest(request))) + ); } /* @@ -164,7 +222,7 @@ private void discover() throws IOException { for (Extension extension : extensions) { loadExtension(extension); } - if (!uninitializedExtensions.isEmpty()) { + if (!extensionIdMap.isEmpty()) { logger.info("Loaded all extensions"); } } else { @@ -177,9 +235,11 @@ private void discover() throws IOException { * @param extension The extension to be loaded */ private void loadExtension(Extension extension) throws IOException { - try { - uninitializedExtensions.add( - new DiscoveryExtensionNode( + if (extensionIdMap.containsKey(extension.getUniqueId())) { + logger.info("Duplicate uniqueId " + extension.getUniqueId() + ". Did not load extension: " + extension); + } else { + try { + DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode( extension.getName(), extension.getUniqueId(), // placeholder for ephemeral id, will change with POC discovery @@ -199,42 +259,52 @@ private void loadExtension(Extension extension) throws IOException { new ArrayList(), Boolean.parseBoolean(extension.hasNativeController()) ) - ) - ); - logger.info("Loaded extension: " + extension); - } catch (IllegalArgumentException e) { - throw e; + ); + extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode); + logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension); + } catch (IllegalArgumentException e) { + throw e; + } } } + /** + * Iterate through all extensions and initialize them. Initialized extensions will be added to the {@link #extensions}, and the {@link #namedWriteableRegistry} will be initialized. + */ public void initialize() { - for (DiscoveryNode extensionNode : uninitializedExtensions) { - initializeExtension(extensionNode); + for (DiscoveryExtensionNode extension : extensionIdMap.values()) { + initializeExtension(extension); } + this.namedWriteableRegistry = new ExtensionNamedWriteableRegistry(extensions, transportService); } - private void initializeExtension(DiscoveryNode extensionNode) { + private void initializeExtension(DiscoveryExtensionNode extension) { - final TransportResponseHandler pluginResponseHandler = new TransportResponseHandler() { + final CompletableFuture inProgressFuture = new CompletableFuture<>(); + final TransportResponseHandler initializeExtensionResponseHandler = new TransportResponseHandler< + InitializeExtensionResponse>() { @Override - public PluginResponse read(StreamInput in) throws IOException { - return new PluginResponse(in); + public InitializeExtensionResponse read(StreamInput in) throws IOException { + return new InitializeExtensionResponse(in); } @Override - public void handleResponse(PluginResponse response) { - for (DiscoveryExtensionNode extension : uninitializedExtensions) { + public void handleResponse(InitializeExtensionResponse response) { + for (DiscoveryExtensionNode extension : extensionIdMap.values()) { if (extension.getName().equals(response.getName())) { extensions.add(extension); + logger.info("Initialized extension: " + extension.getName()); break; } } + inProgressFuture.complete(response); } @Override public void handleException(TransportException exp) { - logger.error(new ParameterizedMessage("Plugin request failed"), exp); + logger.error(new ParameterizedMessage("Extension initialization failed"), exp); + inProgressFuture.completeExceptionally(exp); } @Override @@ -243,39 +313,62 @@ public String executor() { } }; try { - transportService.connectToExtensionNode(extensionNode); + logger.info("Sending extension request type: " + REQUEST_EXTENSION_ACTION_NAME); + transportService.connectToExtensionNode(extension); transportService.sendRequest( - extensionNode, + extension, REQUEST_EXTENSION_ACTION_NAME, - new PluginRequest(transportService.getLocalNode(), new ArrayList(uninitializedExtensions)), - pluginResponseHandler + new InitializeExtensionRequest(transportService.getLocalNode(), extension), + initializeExtensionResponseHandler ); + inProgressFuture.get(100, TimeUnit.SECONDS); } catch (Exception e) { - throw e; + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } + /** + * Handles a {@link RegisterTransportActionsRequest}. + * + * @param transportActionsRequest The request to handle. + * @return A {@link ExtensionBooleanResponse} indicating success. + * @throws Exception if the request is not handled properly. + */ + TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) throws Exception { + /* + * TODO: https://github.com/opensearch-project/opensearch-sdk-java/issues/107 + * Register these new Transport Actions with ActionModule + * and add support for NodeClient to recognise these actions when making transport calls. + */ + return new ExtensionBooleanResponse(true); + } + + /** + * Handles an {@link ExtensionRequest}. + * + * @param extensionRequest The request to handle, of a type defined in the {@link RequestType} enum. + * @return an Response matching the request. + * @throws Exception if the request is not handled properly. + */ TransportResponse handleExtensionRequest(ExtensionRequest extensionRequest) throws Exception { - // Read enum - if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_CLUSTER_STATE) { - ClusterStateResponse clusterStateResponse = new ClusterStateResponse( - clusterService.getClusterName(), - clusterService.state(), - false - ); - return clusterStateResponse; - } else if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_LOCAL_NODE) { - LocalNodeResponse localNodeResponse = new LocalNodeResponse(clusterService); - return localNodeResponse; - } else if (extensionRequest.getRequestType() == RequestType.REQUEST_EXTENSION_CLUSTER_SETTINGS) { - ClusterSettingsResponse clusterSettingsResponse = new ClusterSettingsResponse(clusterService); - return clusterSettingsResponse; + switch (extensionRequest.getRequestType()) { + case REQUEST_EXTENSION_CLUSTER_STATE: + return new ClusterStateResponse(clusterService.getClusterName(), clusterService.state(), false); + case REQUEST_EXTENSION_LOCAL_NODE: + return new LocalNodeResponse(clusterService); + case REQUEST_EXTENSION_CLUSTER_SETTINGS: + return new ClusterSettingsResponse(clusterService); + default: + throw new IllegalStateException("Handler not present for the provided request"); } - throw new IllegalStateException("Handler not present for the provided request: " + extensionRequest.getRequestType()); } public void onIndexModule(IndexModule indexModule) throws UnknownHostException { - for (DiscoveryNode extensionNode : uninitializedExtensions) { + for (DiscoveryNode extensionNode : extensionIdMap.values()) { onIndexModule(indexModule, extensionNode); } } @@ -283,11 +376,11 @@ public void onIndexModule(IndexModule indexModule) throws UnknownHostException { private void onIndexModule(IndexModule indexModule, DiscoveryNode extensionNode) throws UnknownHostException { logger.info("onIndexModule index:" + indexModule.getIndex()); final CompletableFuture inProgressFuture = new CompletableFuture<>(); - final CompletableFuture inProgressIndexNameFuture = new CompletableFuture<>(); - final TransportResponseHandler acknowledgedResponseHandler = new TransportResponseHandler< - AcknowledgedResponse>() { + final CompletableFuture inProgressIndexNameFuture = new CompletableFuture<>(); + final TransportResponseHandler extensionBooleanResponseHandler = new TransportResponseHandler< + ExtensionBooleanResponse>() { @Override - public void handleResponse(AcknowledgedResponse response) { + public void handleResponse(ExtensionBooleanResponse response) { logger.info("ACK Response" + response); inProgressIndexNameFuture.complete(response); } @@ -303,8 +396,8 @@ public String executor() { } @Override - public AcknowledgedResponse read(StreamInput in) throws IOException { - return new AcknowledgedResponse(in); + public ExtensionBooleanResponse read(StreamInput in) throws IOException { + return new ExtensionBooleanResponse(in); } }; @@ -331,20 +424,24 @@ public void beforeIndexRemoved( String indexName = indexService.index().getName(); logger.info("Index Name" + indexName.toString()); try { - logger.info("Sending request of index name to extension"); + logger.info("Sending extension request type: " + INDICES_EXTENSION_NAME_ACTION_NAME); transportService.sendRequest( extensionNode, INDICES_EXTENSION_NAME_ACTION_NAME, new IndicesModuleRequest(indexModule), - acknowledgedResponseHandler + extensionBooleanResponseHandler ); /* - * Making async synchronous for now. + * Making asynchronous for now. */ inProgressIndexNameFuture.get(100, TimeUnit.SECONDS); logger.info("Received ack response from Extension"); } catch (Exception e) { - logger.error(e.toString()); + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } }); @@ -365,7 +462,7 @@ public String executor() { }; try { - logger.info("Sending request to extension"); + logger.info("Sending extension request type: " + INDICES_EXTENSION_POINT_ACTION_NAME); transportService.sendRequest( extensionNode, INDICES_EXTENSION_POINT_ACTION_NAME, @@ -378,7 +475,11 @@ public String executor() { inProgressFuture.get(100, TimeUnit.SECONDS); logger.info("Received response from Extension"); } catch (Exception e) { - logger.error(e.toString()); + try { + throw e; + } catch (Exception e1) { + logger.error(e.toString()); + } } } @@ -421,10 +522,6 @@ public Path getExtensionsPath() { return extensionsPath; } - public List getUninitializedExtensions() { - return uninitializedExtensions; - } - public List getExtensions() { return extensions; } @@ -437,4 +534,32 @@ public ClusterService getClusterService() { return clusterService; } + public static String getRequestExtensionRegisterRestActions() { + return REQUEST_EXTENSION_REGISTER_REST_ACTIONS; + } + + public static String getRequestOpensearchNamedWriteableRegistry() { + return REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY; + } + + public static String getRequestOpensearchParseNamedWriteable() { + return REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE; + } + + public static String getRequestRestExecuteOnExtensionAction() { + return REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION; + } + + public Map getExtensionIdMap() { + return extensionIdMap; + } + + public RestActionsRequestHandler getRestActionsRequestHandler() { + return restActionsRequestHandler; + } + + public ExtensionNamedWriteableRegistry getNamedWriteableRegistry() { + return namedWriteableRegistry; + } + } diff --git a/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryParseResponseHandler.java b/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryParseResponseHandler.java new file mode 100644 index 0000000000000..721cfcb7b656f --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryParseResponseHandler.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import java.io.IOException; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; + +/** + * Response handler for NamedWriteableRegistryParse Requests + * + * @opensearch.internal + */ +public class NamedWriteableRegistryParseResponseHandler implements TransportResponseHandler { + private static final Logger logger = LogManager.getLogger(NamedWriteableRegistryParseResponseHandler.class); + + @Override + public ExtensionBooleanResponse read(StreamInput in) throws IOException { + return new ExtensionBooleanResponse(in); + } + + @Override + public void handleResponse(ExtensionBooleanResponse response) { + logger.info("response {}", response.getStatus()); + } + + @Override + public void handleException(TransportException exp) { + logger.error(new ParameterizedMessage("NamedWriteableRegistryParseRequest failed"), exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryResponseHandler.java b/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryResponseHandler.java new file mode 100644 index 0000000000000..a4e904a3f0f8f --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/NamedWriteableRegistryResponseHandler.java @@ -0,0 +1,142 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import java.io.IOException; +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.NamedWriteableRegistryParseRequest; +import org.opensearch.common.io.stream.NamedWriteableRegistryResponse; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +/** + * Response handler for NamedWriteableRegistry Requests + * + * @opensearch.internal + */ +public class NamedWriteableRegistryResponseHandler implements TransportResponseHandler { + private static final Logger logger = LogManager.getLogger(NamedWriteableRegistryResponseHandler.class); + + private final Map>> extensionRegistry; + private final DiscoveryNode extensionNode; + private final TransportService transportService; + private final String requestType; + + /** + * Instantiates a new NamedWriteableRegistry response handler + * + * @param extensionNode Discovery Node identifying the extension associated with the category class and name + * @param transportService The transport service communicating with the SDK + * @param requestType The type of request that OpenSearch will send to the SDK + */ + public NamedWriteableRegistryResponseHandler(DiscoveryNode extensionNode, TransportService transportService, String requestType) { + this.extensionRegistry = new HashMap(); + this.extensionNode = extensionNode; + this.transportService = transportService; + this.requestType = requestType; + } + + /** + * @return A map of the given DiscoveryNode and its inner named writeable registry map + */ + public Map>> getExtensionRegistry() { + return Collections.unmodifiableMap(this.extensionRegistry); + } + + /** + * Transports a StreamInput, converted into a byte array, and associated category class to the given extension, identified by its discovery node + * + * @param extensionNode Discovery Node identifying the extension associated with the category class and name + * @param categoryClass Class that the Writeable object extends + * @param context StreamInput object to convert into a byte array and transport to the extension + * @throws UnknownHostException if connection to the extension node failed + */ + public void parseNamedWriteable(DiscoveryNode extensionNode, Class categoryClass, StreamInput context) throws UnknownHostException { + NamedWriteableRegistryParseResponseHandler namedWriteableRegistryParseResponseHandler = + new NamedWriteableRegistryParseResponseHandler(); + try { + logger.info("Sending extension request type: " + requestType); + transportService.sendRequest( + extensionNode, + requestType, + new NamedWriteableRegistryParseRequest(categoryClass, context), + namedWriteableRegistryParseResponseHandler + ); + } catch (Exception e) { + logger.error(e.toString()); + } + } + + @Override + public NamedWriteableRegistryResponse read(StreamInput in) throws IOException { + return new NamedWriteableRegistryResponse(in); + } + + @Override + public void handleResponse(NamedWriteableRegistryResponse response) { + + logger.info("response {}", response); + logger.info("EXTENSION [" + extensionNode.getName() + "] returned " + response.getRegistry().size() + " entries"); + + if (response.getRegistry().isEmpty() == false) { + + // Extension has sent over entries to register, initialize inner category map + Map> categoryMap = new HashMap<>(); + + // Reader map associated with this current category + Map readers = null; + Class currentCategory = null; + + for (Map.Entry entry : response.getRegistry().entrySet()) { + + String name = entry.getKey(); + Class categoryClass = entry.getValue(); + if (currentCategory != categoryClass) { + // After first pass, readers and current category are set + if (currentCategory != null) { + categoryMap.put(currentCategory, readers); + } + readers = new HashMap<>(); + currentCategory = categoryClass; + } + + // Add name and callback method reference to inner reader map, + ExtensionReader callBack = (en, cc, context) -> parseNamedWriteable(en, cc, (StreamInput) context); + readers.put(name, callBack); + } + + // Handle last category and reader entry + categoryMap.put(currentCategory, readers); + + // Attach extension node to categoryMap + extensionRegistry.put(extensionNode, categoryMap); + } + } + + @Override + public void handleException(TransportException exp) { + logger.error(new ParameterizedMessage("OpenSearchRequest failed"), exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java b/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java new file mode 100644 index 0000000000000..62e66f09eb856 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/OpenSearchRequest.java @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.Objects; + +/** + * Request from OpenSearch to an Extension + * + * @opensearch.internal + */ +public class OpenSearchRequest extends TransportRequest { + + private static final Logger logger = LogManager.getLogger(OpenSearchRequest.class); + private ExtensionsManager.OpenSearchRequestType requestType; + + /** + * @param requestType String identifying the default extension point to invoke on the extension + */ + public OpenSearchRequest(ExtensionsManager.OpenSearchRequestType requestType) { + this.requestType = requestType; + } + + /** + * @param in StreamInput from which a string identifying the default extension point to invoke on the extension is read from + */ + public OpenSearchRequest(StreamInput in) throws IOException { + super(in); + this.requestType = in.readEnum(ExtensionsManager.OpenSearchRequestType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(requestType); + } + + @Override + public String toString() { + return "OpenSearchRequest{" + "requestType=" + requestType + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OpenSearchRequest that = (OpenSearchRequest) o; + return Objects.equals(requestType, that.requestType); + } + + @Override + public int hashCode() { + return Objects.hash(requestType); + } + + public ExtensionsManager.OpenSearchRequestType getRequestType() { + return this.requestType; + } + +} diff --git a/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java b/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java new file mode 100644 index 0000000000000..a3603aaf22dd0 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Request to register extension Transport actions + * + * @opensearch.internal + */ +public class RegisterTransportActionsRequest extends TransportRequest { + private Map transportActions; + + public RegisterTransportActionsRequest(Map transportActions) { + this.transportActions = new HashMap<>(transportActions); + } + + public RegisterTransportActionsRequest(StreamInput in) throws IOException { + super(in); + Map actions = new HashMap<>(); + int actionCount = in.readVInt(); + for (int i = 0; i < actionCount; i++) { + try { + String actionName = in.readString(); + Class transportAction = Class.forName(in.readString()); + actions.put(actionName, transportAction); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Could not read transport action"); + } + } + this.transportActions = actions; + } + + public Map getTransportActions() { + return transportActions; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(this.transportActions.size()); + for (Map.Entry action : transportActions.entrySet()) { + out.writeString(action.getKey()); + out.writeString(action.getValue().getName()); + } + } + + @Override + public String toString() { + return "TransportActionsRequest{actions=" + transportActions + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RegisterTransportActionsRequest that = (RegisterTransportActionsRequest) obj; + return Objects.equals(transportActions, that.transportActions); + } + + @Override + public int hashCode() { + return Objects.hash(transportActions); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java new file mode 100644 index 0000000000000..8c190ff416a62 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsRequest.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * Request to register extension REST actions + * + * @opensearch.internal + */ +public class RegisterRestActionsRequest extends TransportRequest { + private String uniqueId; + private List restActions; + + public RegisterRestActionsRequest(String uniqueId, List restActions) { + this.uniqueId = uniqueId; + this.restActions = new ArrayList<>(restActions); + } + + public RegisterRestActionsRequest(StreamInput in) throws IOException { + super(in); + uniqueId = in.readString(); + restActions = in.readStringList(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(uniqueId); + out.writeStringCollection(restActions); + } + + public String getUniqueId() { + return uniqueId; + } + + public List getRestActions() { + return new ArrayList<>(restActions); + } + + @Override + public String toString() { + return "RestActionsRequest{uniqueId=" + uniqueId + ", restActions=" + restActions + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RegisterRestActionsRequest that = (RegisterRestActionsRequest) obj; + return Objects.equals(uniqueId, that.uniqueId) && Objects.equals(restActions, that.restActions); + } + + @Override + public int hashCode() { + return Objects.hash(uniqueId, restActions); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java new file mode 100644 index 0000000000000..c0a79ad32ce89 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RegisterRestActionsResponse.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportResponse; + +import java.io.IOException; + +/** + * Response to register REST Actions request. + * + * @opensearch.internal + */ +public class RegisterRestActionsResponse extends TransportResponse { + private String response; + + public RegisterRestActionsResponse(String response) { + this.response = response; + } + + public RegisterRestActionsResponse(StreamInput in) throws IOException { + response = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(response); + } + + public String getResponse() { + return response; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java new file mode 100644 index 0000000000000..e24f5d519bf81 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportService; + +import java.util.Map; + +/** + * Handles requests to register extension REST actions. + * + * @opensearch.internal + */ +public class RestActionsRequestHandler { + + private final RestController restController; + private final Map extensionIdMap; + private final TransportService transportService; + + /** + * Instantiates a new REST Actions Request Handler using the Node's RestController. + * + * @param restController The Node's {@link RestController}. + * @param extensionIdMap A map of extension uniqueId to DiscoveryExtensionNode + * @param transportService The Node's transportService + */ + public RestActionsRequestHandler( + RestController restController, + Map extensionIdMap, + TransportService transportService + ) { + this.restController = restController; + this.extensionIdMap = extensionIdMap; + this.transportService = transportService; + } + + /** + * Handles a {@link RegisterRestActionsRequest}. + * + * @param restActionsRequest The request to handle. + * @return A {@link RegisterRestActionsResponse} indicating success. + * @throws Exception if the request is not handled properly. + */ + public TransportResponse handleRegisterRestActionsRequest(RegisterRestActionsRequest restActionsRequest) throws Exception { + DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId()); + RestHandler handler = new RestSendToExtensionAction(restActionsRequest, discoveryExtensionNode, transportService); + restController.registerHandler(handler); + return new RegisterRestActionsResponse( + "Registered extension " + restActionsRequest.getUniqueId() + " to handle REST Actions " + restActionsRequest.getRestActions() + ); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java new file mode 100644 index 0000000000000..128dad2645b42 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionRequest.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.Objects; + +/** + * Request to execute REST actions on extension node + * + * @opensearch.internal + */ +public class RestExecuteOnExtensionRequest extends TransportRequest { + + private Method method; + private String uri; + + public RestExecuteOnExtensionRequest(Method method, String uri) { + this.method = method; + this.uri = uri; + } + + public RestExecuteOnExtensionRequest(StreamInput in) throws IOException { + super(in); + try { + method = RestRequest.Method.valueOf(in.readString()); + } catch (IllegalArgumentException e) { + throw new IOException(e); + } + uri = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(method.name()); + out.writeString(uri); + } + + public Method getMethod() { + return method; + } + + public String getUri() { + return uri; + } + + @Override + public String toString() { + return "RestExecuteOnExtensionRequest{method=" + method + ", uri=" + uri + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RestExecuteOnExtensionRequest that = (RestExecuteOnExtensionRequest) obj; + return Objects.equals(method, that.method) && Objects.equals(uri, that.uri); + } + + @Override + public int hashCode() { + return Objects.hash(method, uri); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java new file mode 100644 index 0000000000000..b7d7aae3faaab --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java @@ -0,0 +1,112 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; +import org.opensearch.transport.TransportResponse; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Response to execute REST Actions on the extension node. Wraps the components of a {@link RestResponse}. + * + * @opensearch.internal + */ +public class RestExecuteOnExtensionResponse extends TransportResponse { + private RestStatus status; + private String contentType; + private byte[] content; + private Map> headers; + + /** + * Instantiate this object with a status and response string. + * + * @param status The REST status. + * @param responseString The response content as a String. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String responseString) { + this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap()); + } + + /** + * Instantiate this object with the components of a {@link RestResponse}. + * + * @param status The REST status. + * @param contentType The type of the content. + * @param content The content. + * @param headers The headers. + */ + public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map> headers) { + setStatus(status); + setContentType(contentType); + setContent(content); + setHeaders(headers); + } + + /** + * Instantiate this object from a Transport Stream + * + * @param in The stream input. + * @throws IOException on transport failure. + */ + public RestExecuteOnExtensionResponse(StreamInput in) throws IOException { + setStatus(RestStatus.readFrom(in)); + setContentType(in.readString()); + setContent(in.readByteArray()); + setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + RestStatus.writeTo(out, status); + out.writeString(contentType); + out.writeByteArray(content); + out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString); + } + + public RestStatus getStatus() { + return status; + } + + public void setStatus(RestStatus status) { + this.status = status; + } + + public String getContentType() { + return contentType; + } + + public void setContentType(String contentType) { + this.contentType = contentType; + } + + public byte[] getContent() { + return content; + } + + public void setContent(byte[] content) { + this.content = content; + } + + public Map> getHeaders() { + return headers; + } + + public void setHeaders(Map> headers) { + this.headers = Map.copyOf(headers); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java new file mode 100644 index 0000000000000..9b7b81aa946f7 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java @@ -0,0 +1,185 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.unmodifiableList; + +/** + * An action that forwards REST requests to an extension + */ +public class RestSendToExtensionAction extends BaseRestHandler { + + private static final String SEND_TO_EXTENSION_ACTION = "send_to_extension_action"; + private static final Logger logger = LogManager.getLogger(RestSendToExtensionAction.class); + private static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; + + private final List routes; + private final String uriPrefix; + private final DiscoveryExtensionNode discoveryExtensionNode; + private final TransportService transportService; + + /** + * Instantiates this object using a {@link RegisterRestActionsRequest} to populate the routes. + * + * @param restActionsRequest A request encapsulating a list of Strings with the API methods and URIs. + * @param transportService The OpenSearch transport service + * @param discoveryExtensionNode The extension node to which to send actions + */ + public RestSendToExtensionAction( + RegisterRestActionsRequest restActionsRequest, + DiscoveryExtensionNode discoveryExtensionNode, + TransportService transportService + ) { + this.uriPrefix = "/_extensions/_" + restActionsRequest.getUniqueId(); + List restActionsAsRoutes = new ArrayList<>(); + for (String restAction : restActionsRequest.getRestActions()) { + RestRequest.Method method; + String uri; + try { + int delim = restAction.indexOf(' '); + method = RestRequest.Method.valueOf(restAction.substring(0, delim)); + uri = uriPrefix + restAction.substring(delim).trim(); + } catch (IndexOutOfBoundsException | IllegalArgumentException e) { + throw new IllegalArgumentException(restAction + " does not begin with a valid REST method"); + } + logger.info("Registering: " + method + " " + uri); + restActionsAsRoutes.add(new Route(method, uri)); + } + this.routes = unmodifiableList(restActionsAsRoutes); + this.discoveryExtensionNode = discoveryExtensionNode; + this.transportService = transportService; + } + + @Override + public String getName() { + return SEND_TO_EXTENSION_ACTION; + } + + @Override + public List routes() { + return this.routes; + } + + @Override + public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { + Method method = request.getHttpRequest().method(); + String uri = request.getHttpRequest().uri(); + if (uri.startsWith(uriPrefix)) { + uri = uri.substring(uriPrefix.length()); + } + String message = "Forwarding the request " + method + " " + uri + " to " + discoveryExtensionNode; + logger.info(message); + // Initialize response. Values will be changed in the handler. + final RestExecuteOnExtensionResponse restExecuteOnExtensionResponse = new RestExecuteOnExtensionResponse( + RestStatus.INTERNAL_SERVER_ERROR, + BytesRestResponse.TEXT_CONTENT_TYPE, + message.getBytes(StandardCharsets.UTF_8), + emptyMap() + ); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); + final TransportResponseHandler restExecuteOnExtensionResponseHandler = new TransportResponseHandler< + RestExecuteOnExtensionResponse>() { + + @Override + public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException { + return new RestExecuteOnExtensionResponse(in); + } + + @Override + public void handleResponse(RestExecuteOnExtensionResponse response) { + logger.info("Received response from extension: {}", response.getStatus()); + restExecuteOnExtensionResponse.setStatus(response.getStatus()); + restExecuteOnExtensionResponse.setContentType(response.getContentType()); + restExecuteOnExtensionResponse.setContent(response.getContent()); + // Extract the consumed parameters from the header + Map> headers = response.getHeaders(); + List consumedParams = headers.get(CONSUMED_PARAMS_KEY); + if (consumedParams != null) { + consumedParams.stream().forEach(p -> request.param(p)); + } + Map> headersWithoutConsumedParams = headers.entrySet() + .stream() + .filter(e -> !e.getKey().equals(CONSUMED_PARAMS_KEY)) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + restExecuteOnExtensionResponse.setHeaders(headersWithoutConsumedParams); + inProgressFuture.complete(response); + } + + @Override + public void handleException(TransportException exp) { + logger.error("REST request failed", exp); + // Status is already defaulted to 500 (INTERNAL_SERVER_ERROR) + byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); + restExecuteOnExtensionResponse.setContent(responseBytes); + inProgressFuture.completeExceptionally(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } + }; + try { + transportService.sendRequest( + discoveryExtensionNode, + ExtensionsManager.REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION, + new RestExecuteOnExtensionRequest(method, uri), + restExecuteOnExtensionResponseHandler + ); + try { + inProgressFuture.get(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + return channel -> channel.sendResponse( + new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, "No response from extension to request.") + ); + } + } catch (Exception e) { + logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), e); + } + BytesRestResponse restResponse = new BytesRestResponse( + restExecuteOnExtensionResponse.getStatus(), + restExecuteOnExtensionResponse.getContentType(), + restExecuteOnExtensionResponse.getContent() + ); + for (Entry> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) { + for (String value : headerEntry.getValue()) { + restResponse.addHeader(headerEntry.getKey(), value); + } + } + + return channel -> channel.sendResponse(restResponse); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/package-info.java b/server/src/main/java/org/opensearch/extensions/rest/package-info.java new file mode 100644 index 0000000000000..5a52a295da6ad --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** REST Actions classes for the extensions package. OpenSearch extensions provide extensibility to OpenSearch.*/ +package org.opensearch.extensions.rest; diff --git a/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java b/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java deleted file mode 100644 index 5993a81158d30..0000000000000 --- a/server/src/main/java/org/opensearch/index/AcknowledgedResponse.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.index; - -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportResponse; - -import java.io.IOException; - -/** - * Response for index name of onIndexModule extension point - * - * @opensearch.internal - */ -public class AcknowledgedResponse extends TransportResponse { - private boolean requestAck; - - public AcknowledgedResponse(StreamInput in) throws IOException { - this.requestAck = in.readBoolean(); - } - - public AcknowledgedResponse(Boolean requestAck) { - this.requestAck = requestAck; - } - - public void AcknowledgedResponse(StreamInput in) throws IOException { - this.requestAck = in.readBoolean(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeBoolean(requestAck); - } - -} diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index f204723709965..25b821430ce4e 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -783,6 +783,7 @@ protected Node( modules.add(actionModule); final RestController restController = actionModule.getRestController(); + final NetworkModule networkModule = new NetworkModule( settings, pluginsService.filterPlugins(NetworkPlugin.class), @@ -827,8 +828,7 @@ protected Node( taskHeaders ); if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { - this.extensionsManager.setTransportService(transportService); - this.extensionsManager.setClusterService(clusterService); + this.extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); } final GatewayMetaState gatewayMetaState = new GatewayMetaState(); final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService); diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index cbd86378c0fac..00d521e079f1b 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -19,7 +19,9 @@ import static org.mockito.Mockito.mock; import static org.opensearch.test.ClusterServiceUtils.createClusterService; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.net.InetAddress; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -30,7 +32,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; @@ -38,6 +43,7 @@ import org.junit.Before; import org.opensearch.Version; import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterSettingsResponse; import org.opensearch.cluster.LocalNodeResponse; import org.opensearch.cluster.metadata.IndexMetadata; @@ -45,7 +51,12 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.PathUtils; +import org.opensearch.common.io.stream.InputStreamStreamInput; +import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.NamedWriteableRegistryResponse; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; @@ -54,6 +65,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.env.Environment; import org.opensearch.env.TestEnvironment; +import org.opensearch.extensions.rest.RegisterRestActionsRequest; +import org.opensearch.extensions.rest.RegisterRestActionsResponse; import org.opensearch.index.IndexModule; import org.opensearch.index.IndexSettings; import org.opensearch.index.analysis.AnalysisRegistry; @@ -61,27 +74,58 @@ import org.opensearch.index.engine.InternalEngineFactory; import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestController; import org.opensearch.test.IndexSettingsModule; import org.opensearch.test.MockLogAppender; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.MockTransportService; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; import org.opensearch.transport.nio.MockNioTransport; +import org.opensearch.usage.UsageService; public class ExtensionsManagerTests extends OpenSearchTestCase { private TransportService transportService; + private RestController restController; private ClusterService clusterService; private MockNioTransport transport; + private Path extensionDir; private final ThreadPool threadPool = new TestThreadPool(ExtensionsManagerTests.class.getSimpleName()); private final Settings settings = Settings.builder() .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) .build(); + private final List extensionsYmlLines = Arrays.asList( + "extensions:", + " - name: firstExtension", + " uniqueId: uniqueid1", + " hostName: 'myIndependentPluginHost1'", + " hostAddress: '127.0.0.0'", + " port: '9300'", + " version: '0.0.7'", + " description: Fake description 1", + " opensearchVersion: '3.0.0'", + " javaVersion: '14'", + " className: fakeClass1", + " customFolderName: fakeFolder1", + " hasNativeController: false", + " - name: secondExtension", + " uniqueId: 'uniqueid2'", + " hostName: 'myIndependentPluginHost2'", + " hostAddress: '127.0.0.1'", + " port: '9301'", + " version: '3.14.16'", + " description: Fake description 2", + " opensearchVersion: '2.0.0'", + " javaVersion: '17'", + " className: fakeClass2", + " customFolderName: fakeFolder2", + " hasNativeController: true" + ); @Before public void setup() throws Exception { @@ -112,9 +156,19 @@ public void setup() throws Exception { null, Collections.emptySet() ); + restController = new RestController( + emptySet(), + null, + new NodeClient(Settings.EMPTY, threadPool), + new NoneCircuitBreakerService(), + new UsageService() + ); clusterService = createClusterService(threadPool); + + extensionDir = createTempDir(); } + @Override @After public void tearDown() throws Exception { super.tearDown(); @@ -122,36 +176,9 @@ public void tearDown() throws Exception { ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } - public void testExtensionsDiscovery() throws Exception { + public void testDiscover() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" - ); Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); @@ -203,7 +230,48 @@ public void testExtensionsDiscovery() throws Exception { ) ) ); - assertEquals(expectedUninitializedExtensions, extensionsManager.getUninitializedExtensions()); + assertEquals(expectedUninitializedExtensions.size(), extensionsManager.getExtensionIdMap().values().size()); + assertTrue(expectedUninitializedExtensions.containsAll(extensionsManager.getExtensionIdMap().values())); + assertTrue(extensionsManager.getExtensionIdMap().values().containsAll(expectedUninitializedExtensions)); + } + + public void testNonUniqueExtensionsDiscovery() throws Exception { + Path extensionDir = createTempDir(); + + List nonUniqueYmlLines = extensionsYmlLines.stream() + .map(s -> s.replace("uniqueid2", "uniqueid1")) + .collect(Collectors.toList()); + Files.write(extensionDir.resolve("extensions.yml"), nonUniqueYmlLines, StandardCharsets.UTF_8); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + List expectedUninitializedExtensions = new ArrayList(); + + expectedUninitializedExtensions.add( + new DiscoveryExtensionNode( + "firstExtension", + "uniqueid1", + "uniqueid1", + "myIndependentPluginHost1", + "127.0.0.0", + new TransportAddress(InetAddress.getByName("127.0.0.0"), 9300), + new HashMap(), + Version.fromString("3.0.0"), + new PluginInfo( + "firstExtension", + "Fake description 1", + "0.0.7", + Version.fromString("3.0.0"), + "14", + "fakeClass1", + new ArrayList(), + false + ) + ) + ); + assertEquals(expectedUninitializedExtensions.size(), extensionsManager.getExtensionIdMap().values().size()); + assertTrue(expectedUninitializedExtensions.containsAll(extensionsManager.getExtensionIdMap().values())); + assertTrue(extensionsManager.getExtensionIdMap().values().containsAll(expectedUninitializedExtensions)); } public void testNonAccessibleDirectory() throws Exception { @@ -216,8 +284,6 @@ public void testNonAccessibleDirectory() throws Exception { } public void testNoExtensionsFile() throws Exception { - Path extensionDir = createTempDir(); - Settings settings = Settings.builder().build(); try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) { @@ -240,8 +306,8 @@ public void testNoExtensionsFile() throws Exception { public void testEmptyExtensionsFile() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList(); - Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + List emptyExtensionsYmlLines = Arrays.asList(); + Files.write(extensionDir.resolve("extensions.yml"), emptyExtensionsYmlLines, StandardCharsets.UTF_8); Settings settings = Settings.builder().build(); @@ -251,69 +317,114 @@ public void testEmptyExtensionsFile() throws Exception { public void testInitialize() throws Exception { Path extensionDir = createTempDir(); - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" - ); Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); transportService.start(); transportService.acceptIncomingRequests(); - extensionsManager.setTransportService(transportService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); - expectThrows(ConnectTransportException.class, () -> extensionsManager.initialize()); + try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) { + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "Connect Transport Exception 1", + "org.opensearch.extensions.ExtensionsManager", + Level.ERROR, + "ConnectTransportException[[firstExtension][127.0.0.0:9300] connect_timeout[30s]]" + ) + ); + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "Connect Transport Exception 2", + "org.opensearch.extensions.ExtensionsManager", + Level.ERROR, + "ConnectTransportException[[secondExtension][127.0.0.1:9301] connect_exception]; nested: ConnectException[Connection refused];" + ) + ); + extensionsManager.initialize(); + + // Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for + // now. + // Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045 + // mockLogAppender.assertAllExpectationsMatched(); + } } - public void testHandleExtensionRequest() throws Exception { + public void testHandleRegisterRestActionsRequest() throws Exception { Path extensionDir = createTempDir(); + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); - extensionsManager.setTransportService(transportService); - extensionsManager.setClusterService(clusterService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("GET /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + TransportResponse response = extensionsManager.getRestActionsRequestHandler() + .handleRegisterRestActionsRequest(registerActionsRequest); + assertEquals(RegisterRestActionsResponse.class, response.getClass()); + assertTrue(((RegisterRestActionsResponse) response).getResponse().contains(uniqueIdStr)); + assertTrue(((RegisterRestActionsResponse) response).getResponse().contains(actionsList.toString())); + } + + public void testHandleRegisterRestActionsRequestWithInvalidMethod() throws Exception { + + Path extensionDir = createTempDir(); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("FOO /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + expectThrows( + IllegalArgumentException.class, + () -> extensionsManager.getRestActionsRequestHandler().handleRegisterRestActionsRequest(registerActionsRequest) + ); + } + + public void testHandleRegisterRestActionsRequestWithInvalidUri() throws Exception { + + Path extensionDir = createTempDir(); + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + String uniqueIdStr = "uniqueid1"; + List actionsList = List.of("GET", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList); + expectThrows( + IllegalArgumentException.class, + () -> extensionsManager.getRestActionsRequestHandler().handleRegisterRestActionsRequest(registerActionsRequest) + ); + } + + public void testHandleExtensionRequest() throws Exception { + + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); ExtensionRequest clusterStateRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_CLUSTER_STATE); - assertEquals(extensionsManager.handleExtensionRequest(clusterStateRequest).getClass(), ClusterStateResponse.class); + assertEquals(ClusterStateResponse.class, extensionsManager.handleExtensionRequest(clusterStateRequest).getClass()); ExtensionRequest clusterSettingRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_CLUSTER_SETTINGS); - assertEquals(extensionsManager.handleExtensionRequest(clusterSettingRequest).getClass(), ClusterSettingsResponse.class); + assertEquals(ClusterSettingsResponse.class, extensionsManager.handleExtensionRequest(clusterSettingRequest).getClass()); ExtensionRequest localNodeRequest = new ExtensionRequest(ExtensionsManager.RequestType.REQUEST_EXTENSION_LOCAL_NODE); - assertEquals(extensionsManager.handleExtensionRequest(localNodeRequest).getClass(), LocalNodeResponse.class); + assertEquals(LocalNodeResponse.class, extensionsManager.handleExtensionRequest(localNodeRequest).getClass()); ExtensionRequest exceptionRequest = new ExtensionRequest(ExtensionsManager.RequestType.GET_SETTINGS); Exception exception = expectThrows(IllegalStateException.class, () -> extensionsManager.handleExtensionRequest(exceptionRequest)); - assertEquals(exception.getMessage(), "Handler not present for the provided request: " + exceptionRequest.getRequestType()); + assertEquals("Handler not present for the provided request", exception.getMessage()); } public void testRegisterHandler() throws Exception { - Path extensionDir = createTempDir(); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); @@ -329,49 +440,181 @@ public void testRegisterHandler() throws Exception { ) ); - extensionsManager.setTransportService(mockTransportService); - verify(mockTransportService, times(3)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); + extensionsManager.initializeServicesAndRestHandler(restController, mockTransportService, clusterService); + verify(mockTransportService, times(5)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); } - public void testOnIndexModule() throws Exception { + private static class Example implements NamedWriteable { + public static final String INVALID_NAME = "invalid_name"; + public static final String NAME = "example"; + private final String message; - Path extensionDir = createTempDir(); + Example(String message) { + this.message = message; + } + + Example(StreamInput in) throws IOException { + this.message = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(message); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + Example that = (Example) o; + return Objects.equals(message, that.message); + } - List extensionsYmlLines = Arrays.asList( - "extensions:", - " - name: firstExtension", - " uniqueId: uniqueid1", - " hostName: 'myIndependentPluginHost1'", - " hostAddress: '127.0.0.0'", - " port: '9300'", - " version: '0.0.7'", - " description: Fake description 1", - " opensearchVersion: '3.0.0'", - " javaVersion: '14'", - " className: fakeClass1", - " customFolderName: fakeFolder1", - " hasNativeController: false", - " - name: secondExtension", - " uniqueId: 'uniqueid2'", - " hostName: 'myIndependentPluginHost2'", - " hostAddress: '127.0.0.1'", - " port: '9301'", - " version: '3.14.16'", - " description: Fake description 2", - " opensearchVersion: '2.0.0'", - " javaVersion: '17'", - " className: fakeClass2", - " customFolderName: fakeFolder2", - " hasNativeController: true" + @Override + public int hashCode() { + return Objects.hash(message); + } + } + + public void testGetNamedWriteables() throws Exception { + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + transportService.start(); + transportService.acceptIncomingRequests(); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + + try ( + MockLogAppender mockLogAppender = MockLogAppender.createForLoggers( + LogManager.getLogger(NamedWriteableRegistryResponseHandler.class) + ) + ) { + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "OpenSearchRequest Failure", + "org.opensearch.extensions.NamedWriteableRegistryResponseHandler", + Level.ERROR, + "OpenSearchRequest failed" + ) + ); + + List extensionsList = new ArrayList<>(extensionsManager.getExtensionIdMap().values()); + extensionsManager.namedWriteableRegistry = new ExtensionNamedWriteableRegistry(extensionsList, transportService); + extensionsManager.namedWriteableRegistry.getNamedWriteables(); + mockLogAppender.assertAllExpectationsMatched(); + } + } + + public void testNamedWriteableRegistryResponseHandler() throws Exception { + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + transportService.start(); + transportService.acceptIncomingRequests(); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + + List extensionsList = new ArrayList<>(extensionsManager.getExtensionIdMap().values()); + DiscoveryNode extensionNode = extensionsList.get(0); + String requestType = ExtensionsManager.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY; + + // Create response to pass to response handler + Map responseRegistry = new HashMap<>(); + responseRegistry.put(Example.NAME, Example.class); + NamedWriteableRegistryResponse response = new NamedWriteableRegistryResponse(responseRegistry); + + NamedWriteableRegistryResponseHandler responseHandler = new NamedWriteableRegistryResponseHandler( + extensionNode, + transportService, + requestType ); + responseHandler.handleResponse(response); + + // Ensure that response entries have been processed correctly into their respective maps + Map>> extensionsRegistry = responseHandler.getExtensionRegistry(); + assertEquals(extensionsRegistry.size(), 1); + + Map> categoryMap = extensionsRegistry.get(extensionNode); + assertEquals(categoryMap.size(), 1); + + Map readerMap = categoryMap.get(Example.class); + assertEquals(readerMap.size(), 1); + + ExtensionReader callback = readerMap.get(Example.NAME); + assertNotNull(callback); + } + + public void testGetExtensionReader() throws IOException { + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + + extensionsManager.namedWriteableRegistry = spy( + new ExtensionNamedWriteableRegistry(extensionsManager.getExtensions(), transportService) + ); + + Exception e = expectThrows( + Exception.class, + () -> extensionsManager.namedWriteableRegistry.getExtensionReader(Example.class, Example.NAME) + ); + assertEquals(e.getMessage(), "Unknown NamedWriteable [" + Example.class.getName() + "][" + Example.NAME + "]"); + verify(extensionsManager.namedWriteableRegistry, times(1)).getExtensionReader(any(), any()); + } + + public void testParseNamedWriteables() throws Exception { + Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); + ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); + transportService.start(); + transportService.acceptIncomingRequests(); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); + + String requestType = ExtensionsManager.REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE; + List extensionsList = new ArrayList<>(extensionsManager.getExtensionIdMap().values()); + DiscoveryNode extensionNode = extensionsList.get(0); + Class categoryClass = Example.class; + + // convert context into an input stream then stream input for mock + byte[] context = new byte[0]; + InputStream inputStream = new ByteArrayInputStream(context); + StreamInput in = new InputStreamStreamInput(inputStream); + + try ( + MockLogAppender mockLogAppender = MockLogAppender.createForLoggers( + LogManager.getLogger(NamedWriteableRegistryParseResponseHandler.class) + ) + ) { + + mockLogAppender.addExpectation( + new MockLogAppender.SeenEventExpectation( + "NamedWriteableRegistryParseRequest Failure", + "org.opensearch.extensions.NamedWriteableRegistryParseResponseHandler", + Level.ERROR, + "NamedWriteableRegistryParseRequest failed" + ) + ); + + NamedWriteableRegistryResponseHandler responseHandler = new NamedWriteableRegistryResponseHandler( + extensionNode, + transportService, + requestType + ); + responseHandler.parseNamedWriteable(extensionNode, categoryClass, in); + mockLogAppender.assertAllExpectationsMatched(); + } + } + + public void testOnIndexModule() throws Exception { Files.write(extensionDir.resolve("extensions.yml"), extensionsYmlLines, StandardCharsets.UTF_8); ExtensionsManager extensionsManager = new ExtensionsManager(settings, extensionDir); transportService.start(); transportService.acceptIncomingRequests(); - extensionsManager.setTransportService(transportService); + extensionsManager.initializeServicesAndRestHandler(restController, transportService, clusterService); Environment environment = TestEnvironment.newEnvironment(settings); AnalysisRegistry emptyAnalysisRegistry = new AnalysisRegistry( diff --git a/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java new file mode 100644 index 0000000000000..ed36cc5290bb1 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions; + +import org.junit.Before; +import org.opensearch.common.collect.Map; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class RegisterTransportActionsRequestTests extends OpenSearchTestCase { + private RegisterTransportActionsRequest originalRequest; + + @Before + public void setup() { + this.originalRequest = new RegisterTransportActionsRequest(Map.of("testAction", Map.class)); + } + + public void testRegisterTransportActionsRequest() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + originalRequest.writeTo(output); + StreamInput input = output.bytes().streamInput(); + RegisterTransportActionsRequest parsedRequest = new RegisterTransportActionsRequest(input); + assertEquals(parsedRequest.getTransportActions(), originalRequest.getTransportActions()); + assertEquals(parsedRequest.getTransportActions().get("testAction"), originalRequest.getTransportActions().get("testAction")); + assertEquals(parsedRequest.getTransportActions().size(), originalRequest.getTransportActions().size()); + assertEquals(parsedRequest.hashCode(), originalRequest.hashCode()); + assertTrue(originalRequest.equals(parsedRequest)); + } + + public void testToString() { + assertEquals(originalRequest.toString(), "TransportActionsRequest{actions={testAction=class org.opensearch.common.collect.Map}}"); + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java new file mode 100644 index 0000000000000..a8f1739ce82f2 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import java.util.List; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.test.OpenSearchTestCase; + +public class RegisterRestActionsTests extends OpenSearchTestCase { + + public void testRegisterRestActionsRequest() throws Exception { + String uniqueIdStr = "uniqueid1"; + List expected = List.of("GET /foo", "PUT /bar", "POST /baz"); + RegisterRestActionsRequest registerRestActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, expected); + + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + List restActions = registerRestActionsRequest.getRestActions(); + assertEquals(expected.size(), restActions.size()); + assertTrue(restActions.containsAll(expected)); + assertTrue(expected.containsAll(restActions)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsRequest.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsRequest = new RegisterRestActionsRequest(in); + + assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId()); + restActions = registerRestActionsRequest.getRestActions(); + assertEquals(expected.size(), restActions.size()); + assertTrue(restActions.containsAll(expected)); + assertTrue(expected.containsAll(restActions)); + } + } + } + + public void testRegisterRestActionsResponse() throws Exception { + String response = "This is a response"; + RegisterRestActionsResponse registerRestActionsResponse = new RegisterRestActionsResponse(response); + + assertEquals(response, registerRestActionsResponse.getResponse()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + registerRestActionsResponse.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + registerRestActionsResponse = new RegisterRestActionsResponse(in); + + assertEquals(response, registerRestActionsResponse.getResponse()); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java new file mode 100644 index 0000000000000..98521ddcf1e26 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RestExecuteOnExtensionTests.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.rest.RestStatus; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +public class RestExecuteOnExtensionTests extends OpenSearchTestCase { + + public void testRestExecuteOnExtensionRequest() throws Exception { + Method expectedMethod = Method.GET; + String expectedUri = "/test/uri"; + RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(expectedMethod, expectedUri); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + request.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + request = new RestExecuteOnExtensionRequest(in); + + assertEquals(expectedMethod, request.getMethod()); + assertEquals(expectedUri, request.getUri()); + } + } + } + + public void testRestExecuteOnExtensionResponse() throws Exception { + RestStatus expectedStatus = RestStatus.OK; + String expectedContentType = BytesRestResponse.TEXT_CONTENT_TYPE; + String expectedResponse = "Test response"; + byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8); + + RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + assertEquals(0, response.getHeaders().size()); + + String headerKey = "foo"; + List headerValueList = List.of("bar", "baz"); + Map> expectedHeaders = Map.of(headerKey, headerValueList); + + response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + List fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + response.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + response = new RestExecuteOnExtensionResponse(in); + + assertEquals(expectedStatus, response.getStatus()); + assertEquals(expectedContentType, response.getContentType()); + assertArrayEquals(expectedResponseBytes, response.getContent()); + + assertEquals(1, expectedHeaders.keySet().size()); + assertTrue(expectedHeaders.containsKey(headerKey)); + + fooList = expectedHeaders.get(headerKey); + assertEquals(2, fooList.size()); + assertTrue(fooList.containsAll(headerValueList)); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java b/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java new file mode 100644 index 0000000000000..2a593a8d251e9 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/RestSendToExtensionActionTests.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.network.NetworkService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.extensions.DiscoveryExtensionNode; +import org.opensearch.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.nio.MockNioTransport; + +public class RestSendToExtensionActionTests extends OpenSearchTestCase { + + private TransportService transportService; + private MockNioTransport transport; + private DiscoveryExtensionNode discoveryExtensionNode; + private final ThreadPool threadPool = new TestThreadPool(RestSendToExtensionActionTests.class.getSimpleName()); + + @Before + public void setup() throws Exception { + Settings settings = Settings.builder().put("cluster.name", "test").build(); + transport = new MockNioTransport( + settings, + Version.CURRENT, + threadPool, + new NetworkService(Collections.emptyList()), + PageCacheRecycler.NON_RECYCLING_INSTANCE, + new NamedWriteableRegistry(Collections.emptyList()), + new NoneCircuitBreakerService() + ); + transportService = new MockTransportService( + settings, + transport, + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + (boundAddress) -> new DiscoveryNode( + "test_node", + "test_node", + boundAddress.publishAddress(), + emptyMap(), + emptySet(), + Version.CURRENT + ), + null, + Collections.emptySet() + ); + discoveryExtensionNode = new DiscoveryExtensionNode( + "firstExtension", + "uniqueid1", + "uniqueid1", + "myIndependentPluginHost1", + "127.0.0.0", + new TransportAddress(InetAddress.getByName("127.0.0.0"), 9300), + new HashMap(), + Version.fromString("3.0.0"), + new PluginInfo( + "firstExtension", + "Fake description 1", + "0.0.7", + Version.fromString("3.0.0"), + "14", + "fakeClass1", + new ArrayList(), + false + ) + ); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + transportService.close(); + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + public void testRestSendToExtensionAction() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("GET /foo", "PUT /bar", "POST /baz") + ); + RestSendToExtensionAction restSendToExtensionAction = new RestSendToExtensionAction( + registerRestActionRequest, + discoveryExtensionNode, + transportService + ); + + assertEquals("send_to_extension_action", restSendToExtensionAction.getName()); + List expected = new ArrayList<>(); + String uriPrefix = "/_extensions/_uniqueid1"; + expected.add(new Route(Method.GET, uriPrefix + "/foo")); + expected.add(new Route(Method.PUT, uriPrefix + "/bar")); + expected.add(new Route(Method.POST, uriPrefix + "/baz")); + + List routes = restSendToExtensionAction.routes(); + assertEquals(expected.size(), routes.size()); + List expectedPaths = expected.stream().map(Route::getPath).collect(Collectors.toList()); + List paths = routes.stream().map(Route::getPath).collect(Collectors.toList()); + List expectedMethods = expected.stream().map(Route::getMethod).collect(Collectors.toList()); + List methods = routes.stream().map(Route::getMethod).collect(Collectors.toList()); + assertTrue(paths.containsAll(expectedPaths)); + assertTrue(expectedPaths.containsAll(paths)); + assertTrue(methods.containsAll(expectedMethods)); + assertTrue(expectedMethods.containsAll(methods)); + } + + public void testRestSendToExtensionActionBadMethod() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("/foo", "PUT /bar", "POST /baz") + ); + expectThrows( + IllegalArgumentException.class, + () -> new RestSendToExtensionAction(registerRestActionRequest, discoveryExtensionNode, transportService) + ); + } + + public void testRestSendToExtensionActionMissingUri() throws Exception { + RegisterRestActionsRequest registerRestActionRequest = new RegisterRestActionsRequest( + "uniqueid1", + List.of("GET", "PUT /bar", "POST /baz") + ); + expectThrows( + IllegalArgumentException.class, + () -> new RestSendToExtensionAction(registerRestActionRequest, discoveryExtensionNode, transportService) + ); + } +}