Skip to content

Commit 7d25d56

Browse files
Nova mme support (opensearch-project#4360)
* Add Nova MME support Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Fix deleted file Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Add UTs Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> --------- Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> Co-authored-by: Xun Zhang <xunzh@amazon.com>
1 parent a00b7de commit 7d25d56

File tree

8 files changed

+455
-0
lines changed

8 files changed

+455
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class MLPostProcessFunction {
3131
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
3232
public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float";
3333
public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary";
34+
public static final String BEDROCK_NOVA_EMBEDDING = "connector.post_process.bedrock.nova.embedding";
3435
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
3536
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
3637
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
@@ -62,6 +63,7 @@ public class MLPostProcessFunction {
6263
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
6364
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float");
6465
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary");
66+
JSON_PATH_EXPRESSION.put(BEDROCK_NOVA_EMBEDDING, "$.embeddings[*].embedding");
6567
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
6668
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
6769
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
@@ -78,6 +80,7 @@ public class MLPostProcessFunction {
7880
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
7981
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction);
8082
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction);
83+
POST_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_EMBEDDING, embeddingPostProcessFunction);
8184
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
8285
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
8386
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
import java.util.Map;
1010
import java.util.function.Function;
1111

12+
import org.opensearch.ml.common.connector.functions.preprocess.AudioEmbeddingPreProcessFunction;
1213
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
1314
import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction;
1415
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
1516
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
1617
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
18+
import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction;
1719
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
1820
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
21+
import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction;
1922
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
2023
import org.opensearch.ml.common.input.MLInput;
2124

@@ -27,6 +30,10 @@ public class MLPreProcessFunction {
2730
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
2831
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
2932
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
33+
public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding";
34+
public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding";
35+
public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding";
36+
public static final String AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.audio_embedding";
3037
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
3138
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
3239
public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank";
@@ -42,11 +49,18 @@ public class MLPreProcessFunction {
4249
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
4350
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
4451
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
52+
ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction();
53+
VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction();
54+
AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction();
4555
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
4656
new CohereMultiModalEmbeddingPreProcessFunction();
4757
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
4858
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
4959
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
60+
PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
61+
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction);
62+
PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction);
63+
PRE_PROCESS_FUNCTIONS.put(AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, audioEmbeddingPreProcessFunction);
5064
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
5165
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
5266
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova audio input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an audio.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class AudioEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public AudioEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input audio provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputAudio", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova image input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an image.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class ImageEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public ImageEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input image provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputImage", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
/**
19+
* This class provides a pre-processing function for Bedrock Nova video input data.
20+
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
21+
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing a video.
22+
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
23+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
24+
*/
25+
public class VideoEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
26+
27+
public VideoEmbeddingPreProcessFunction() {
28+
this.returnDirectlyForRemoteInferenceInput = true;
29+
}
30+
31+
@Override
32+
public void validate(MLInput mlInput) {
33+
validateTextDocsInput(mlInput);
34+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
35+
if (docs.size() == 0) {
36+
throw new IllegalArgumentException("No input video provided");
37+
}
38+
}
39+
40+
/**
41+
* @param mlInput The input data to be processed.
42+
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
43+
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
44+
*/
45+
@Override
46+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
47+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
48+
Map<String, String> parametersMap = new HashMap<>();
49+
parametersMap.put("inputVideo", inputData.getDocs().get(0));
50+
return RemoteInferenceInputDataSet
51+
.builder()
52+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
53+
.build();
54+
55+
}
56+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.when;
11+
12+
import java.util.Arrays;
13+
import java.util.Collections;
14+
import java.util.Map;
15+
16+
import org.junit.Before;
17+
import org.junit.Rule;
18+
import org.junit.Test;
19+
import org.junit.rules.ExpectedException;
20+
import org.opensearch.ml.common.FunctionName;
21+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
22+
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
23+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
24+
import org.opensearch.ml.common.input.MLInput;
25+
26+
public class AudioEmbeddingPreProcessFunctionTest {
27+
@Rule
28+
public ExpectedException exceptionRule = ExpectedException.none();
29+
30+
AudioEmbeddingPreProcessFunction function;
31+
32+
TextSimilarityInputDataSet textSimilarityInputDataSet;
33+
TextDocsInputDataSet textDocsInputDataSet;
34+
RemoteInferenceInputDataSet remoteInferenceInputDataSet;
35+
36+
MLInput textEmbeddingInput;
37+
MLInput textSimilarityInput;
38+
MLInput remoteInferenceInput;
39+
40+
@Before
41+
public void setUp() {
42+
function = new AudioEmbeddingPreProcessFunction();
43+
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
44+
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
45+
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
46+
47+
textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
48+
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
49+
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
50+
}
51+
52+
@Test
53+
public void process_NullInput() {
54+
exceptionRule.expect(IllegalArgumentException.class);
55+
exceptionRule.expectMessage("Preprocess function input can't be null");
56+
function.apply(null);
57+
}
58+
59+
@Test
60+
public void process_WrongInput() {
61+
exceptionRule.expect(IllegalArgumentException.class);
62+
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
63+
function.apply(textSimilarityInput);
64+
}
65+
66+
@Test
67+
public void process_CorrectInput() {
68+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
69+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
70+
assertEquals(1, dataSet.getParameters().size());
71+
assertEquals("hello", dataSet.getParameters().get("inputAudio"));
72+
}
73+
74+
@Test
75+
public void process_EmptyDocs() {
76+
TextDocsInputDataSet mockDataSet = mock(TextDocsInputDataSet.class);
77+
when(mockDataSet.getDocs()).thenReturn(Collections.emptyList());
78+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(mockDataSet).build();
79+
80+
exceptionRule.expect(IllegalArgumentException.class);
81+
exceptionRule.expectMessage("No input audio provided");
82+
function.apply(mlInput);
83+
}
84+
85+
@Test
86+
public void process_RemoteInferenceInput() {
87+
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
88+
assertEquals(remoteInferenceInputDataSet, dataSet);
89+
}
90+
}

0 commit comments

Comments
 (0)