From 9b413a7212743a7820ab2dfb61b3c4b88ead2e12 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Mon, 22 Jul 2024 17:15:38 -0700 Subject: [PATCH] Automated model interface generation on aws llms (#2689) * Automated model interface generation on aws llms Signed-off-by: b4sjoo * Add UTs Signed-off-by: b4sjoo * Add Comments and TODOs Signed-off-by: b4sjoo --------- Signed-off-by: b4sjoo --- .../ml/common/utils/ModelInterfaceUtils.java | 623 ++++++++++++++++++ .../common/utils/ModelInterfaceUtilsTest.java | 215 ++++++ .../TransportRegisterModelAction.java | 13 +- .../opensearch/ml/model/MLModelManager.java | 2 +- .../ml/rest/RestMLGuardrailsIT.java | 16 +- .../ml/rest/RestMLRemoteInferenceIT.java | 12 +- .../ml/tools/ToolIntegrationWithLLMTest.java | 6 +- 7 files changed, 880 insertions(+), 7 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java create mode 100644 common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java new file mode 100644 index 0000000000..f364aba569 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java @@ -0,0 +1,623 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; + +import java.util.Map; + +@Log4j2 +public class ModelInterfaceUtils { + + private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputs\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inputs\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"texts\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"texts\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputText\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inputText\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputText\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"inputImage\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Text\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"Text\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"bytes\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"bytes\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; + + private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"response\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; + + private static final String BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"type\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"completion\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"stop_reason\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"stop\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"type\",\n" + + " \"completion\",\n" + + " \"stop_reason\",\n" + + " \"stop\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; + + private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"data_type\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"shape\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"data\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"data_type\",\n" + + " \"shape\",\n" + + " \"data\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; + + private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Languages\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"LanguageCode\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Score\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"LanguageCode\",\n" + + " \"Score\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"Languages\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"response\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; + + private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Blocks\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"BlockType\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Geometry\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"BoundingBox\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Height\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Left\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Top\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Width\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"Polygon\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"X\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Y\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"Id\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Relationships\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Ids\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"Type\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"DetectDocumentTextModelVersion\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"DocumentMetadata\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Pages\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + + public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE = Map.of( + "input", + GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, + "output", + GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE = Map.of( + "input", + GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, + "output", + GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE = Map.of( + "input", + GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, + "output", + BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE = Map.of( + "input", + GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, + "output", + GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE = Map.of( + "input", + GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, + "output", + GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE = Map.of( + "input", + TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, + "output", + GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT + ); + + public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE = Map.of( + "input", + TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, + "output", + GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT + ); + + public static final Map AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE = Map.of( + "input", + AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT, + "output", + AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT + ); + + public static final Map AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE = Map.of( + "input", + AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT, + "output", + AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT + ); + + private static Map createPresetModelInterfaceByConnector(Connector connector) { + if (connector.getParameters() != null) { + switch ((connector.getParameters().get("service_name") != null) ? connector.getParameters().get("service_name") : "null") { + case "bedrock": + log.info("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); + switch ((connector.getParameters().get("model") != null) ? connector.getParameters().get("model") : "null") { + case "ai21.j2-mid-v1": + return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; + case "anthropic.claude-3-sonnet-20240229-v1:0": + return BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; + case "anthropic.claude-v2": + return BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; + case "cohere.embed.english-v3": + return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; + case "cohere.embed.multilingual-v3": + return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; + case "amazon.titan-embed-text-v1": + return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; + case "amazon.titan-embed-image-v1": + return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; + default: + return null; + } + case "comprehend": + log.info("Creating preset model interface for Amazon Comprehend DetectDominantLanguage API"); + switch ((connector.getParameters().get("api_name") != null) ? connector.getParameters().get("api_name") : "null"){ + // Single case for switch-case statement due to there is one more API in blueprint for Amazon Comprehend Model + // Not set here because there is more than one input/output schema for the DetectEntities API + // TODO: Add default model interface for Amazon Comprehend DetectEntities APIs + case "DetectDominantLanguage": + return AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; + default: + return null; + } + case "textract": + log.info("Creating preset model interface for Amazon Textract DetectDocumentText API"); + return AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE; + default: + return null; + } + } + return null; + } + + /** + * Update the model interface fields of the register model input based on the stand-alone connector + * @param registerModelInput the register model input + * @param connector the connector + */ + public static void updateRegisterModelInputModelInterfaceFieldsByConnector(MLRegisterModelInput registerModelInput, Connector connector) { + Map presetModelInterface = createPresetModelInterfaceByConnector(connector); + if (presetModelInterface != null) { + registerModelInput.setModelInterface(presetModelInterface); + } + } + + /** + * Update the model interface fields of the register model input based on the internal connector + * @param registerModelInput the register model input + */ + public static void updateRegisterModelInputModelInterfaceFieldsByConnector(MLRegisterModelInput registerModelInput) { + Map presetModelInterface = createPresetModelInterfaceByConnector(registerModelInput.getConnector()); + if (presetModelInterface != null) { + registerModelInput.setModelInterface(presetModelInterface); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java new file mode 100644 index 0000000000..be209ea6ee --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java @@ -0,0 +1,215 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Spy; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector; + +public class ModelInterfaceUtilsTest { + @Spy + MLRegisterModelInput registerModelInputWithInnerConnector; + + @Spy + MLRegisterModelInput registerModelInputWithStandaloneConnector; + + @Spy + public HttpConnector connector; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + registerModelInputWithInnerConnector = MLRegisterModelInput + .builder() + .modelName("test-model-with-inner-connector") + .functionName(FunctionName.REMOTE) + .build(); + + registerModelInputWithStandaloneConnector = MLRegisterModelInput + .builder() + .modelName("test-model-with-stand-alone-connector") + .functionName(FunctionName.REMOTE) + .build(); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "ai21.j2-mid-v1"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "anthropic.claude-3-sonnet-20240229-v1:0"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "anthropic.claude-v2"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "cohere.embed.english-v3"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "cohere.embed.multilingual-v3"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "amazon.titan-embed-text-v1"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "amazon.titan-embed-image-v1"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "comprehend"); + parameters.put("api_name", "DetectDominantLanguage"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "textract"); + parameters.put("api_name", "DetectDocumentText"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorServiceNameNotFound() { + Map parameters = new HashMap<>(); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBedrockModelNameNotFound() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAmazonComprehendAPINameNotFound() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "comprehend"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParameters() { + connector = HttpConnector.builder().protocol("http").build(); + + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); + assertNull(registerModelInputWithStandaloneConnector.getModelInterface()); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() { + Map parameters = new HashMap<>(); + parameters.put("service_name", "bedrock"); + parameters.put("model", "ai21.j2-mid-v1"); + connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); + registerModelInputWithInnerConnector.setConnector(connector); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector); + assertEquals(registerModelInputWithInnerConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE); + } + + @Test + public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorNullParameters() { + connector = HttpConnector.builder().protocol("http").build(); + registerModelInputWithInnerConnector.setConnector(connector); + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector); + assertNull(registerModelInputWithInnerConnector.getModelInterface()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index bde53795a3..66adc8c665 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; @@ -239,7 +240,14 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< if (Strings.isNotBlank(registerModelInput.getConnectorId())) { connectorAccessControlHelper.validateConnectorAccess(client, registerModelInput.getConnectorId(), ActionListener.wrap(r -> { if (Boolean.TRUE.equals(r)) { - createModelGroup(registerModelInput, listener); + if (registerModelInput.getModelInterface() == null) { + mlModelManager.getConnector(registerModelInput.getConnectorId(), ActionListener.wrap(connector -> { + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput, connector); + createModelGroup(registerModelInput, listener); + }, listener::onFailure)); + } else { + createModelGroup(registerModelInput, listener); + } } else { listener .onFailure( @@ -261,6 +269,9 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< validateInternalConnector(registerModelInput); ActionListener dryRunResultListener = ActionListener.wrap(res -> { log.info("Dry run create connector successfully"); + if (registerModelInput.getModelInterface() == null) { + updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput); + } createModelGroup(registerModelInput, listener); }, e -> { log.error(e.getMessage(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 78d36a9975..daccdf1569 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -1624,7 +1624,7 @@ public void getController(String modelId, ActionListener listener) * @param connectorId connector id * @param listener action listener */ - private void getConnector(String connectorId, ActionListener listener) { + public void getConnector(String connectorId, ActionListener listener) { GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index 44533d3ae5..6ed0fc4118 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -366,7 +366,11 @@ protected Response registerRemoteModel(String modelGroupName, String name, Strin + " \"description\": \"test model\",\n" + " \"connector_id\": \"" + connectorId - + "\"\n" + + "\",\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + "}"; return TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); @@ -423,7 +427,11 @@ protected Response registerRemoteModelWithLocalRegexGuardrails(String name, Stri + " ],\n" + " \"regex\": [\"regex1\", \"regex2\"]\n" + " }\n" - + " }\n" + + "},\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + "}"; return TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); @@ -461,6 +469,10 @@ protected Response registerRemoteModelWithModelGuardrails(String name, String co + " \"connector_id\": \"" + connectorId + "\",\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " },\n" + " \"guardrails\": {\n" + " \"type\": \"model\",\n" + " \"input_guardrail\": {\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 37ff064286..da074fd95c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -814,7 +814,11 @@ public static Response registerRemoteModel(String modelGroupName, String name, S + " \"description\": \"test model\",\n" + " \"connector_id\": \"" + connectorId - + "\"\n" + + "\",\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + "}"; return TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); @@ -856,7 +860,11 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name + " \"deploy_setting\": " + " { \"model_ttl_minutes\": " + ttl - + "}\n" + + "},\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + "}"; return TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index 3e7c2e64f4..b46d2d6a3b 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -201,7 +201,11 @@ private void setupLLMModel(String connectorId) throws IOException { + " \"description\": \"test model\",\n" + " \"connector_id\": \"" + connectorId - + "\"\n" + + "\",\n" + + " \"interface\": {\n" + + " \"input\": {},\n" + + " \"output\": {}\n" + + " }\n" + "}"; registerModel(client(), input, response -> {