Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class MLPostProcessFunction {
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String BEDROCK_V2_EMBEDDING_FLOAT = "connector.post_process.bedrock_v2.embedding.float";
public static final String BEDROCK_V2_EMBEDDING_BINARY = "connector.post_process.bedrock_v2.embedding.binary";
public static final String BEDROCK_NOVA_EMBEDDING = "connector.post_process.bedrock.nova.embedding";
public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
Expand Down Expand Up @@ -62,6 +63,7 @@ public class MLPostProcessFunction {
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_FLOAT, "$.embeddingsByType.float");
JSON_PATH_EXPRESSION.put(BEDROCK_V2_EMBEDDING_BINARY, "$.embeddingsByType.binary");
JSON_PATH_EXPRESSION.put(BEDROCK_NOVA_EMBEDDING, "$.embeddings[*].embedding");
JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$");
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
Expand All @@ -78,6 +80,7 @@ public class MLPostProcessFunction {
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_FLOAT, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_V2_EMBEDDING_BINARY, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.AudioEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

Expand All @@ -27,6 +30,10 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding";
public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding";
public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding";
public static final String AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.audio_embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank";
Expand All @@ -42,11 +49,18 @@ public class MLPreProcessFunction {
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction();
VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction();
AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction();
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
new CohereMultiModalEmbeddingPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(AUDIO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, audioEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
* This class provides a pre-processing function for Bedrock Nova audio input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an audio.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class AudioEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public AudioEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0) {
throw new IllegalArgumentException("No input audio provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
parametersMap.put("inputAudio", inputData.getDocs().get(0));
return RemoteInferenceInputDataSet
.builder()
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
.build();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
* This class provides a pre-processing function for Bedrock Nova image input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing an image.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class ImageEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public ImageEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0) {
throw new IllegalArgumentException("No input image provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
parametersMap.put("inputImage", inputData.getDocs().get(0));
return RemoteInferenceInputDataSet
.builder()
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
.build();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
* This class provides a pre-processing function for Bedrock Nova video input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with document representing a video.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class VideoEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public VideoEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0) {
throw new IllegalArgumentException("No input video provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
parametersMap.put("inputVideo", inputData.getDocs().get(0));
return RemoteInferenceInputDataSet
.builder()
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
.build();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

public class AudioEmbeddingPreProcessFunctionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

AudioEmbeddingPreProcessFunction function;

TextSimilarityInputDataSet textSimilarityInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
RemoteInferenceInputDataSet remoteInferenceInputDataSet;

MLInput textEmbeddingInput;
MLInput textSimilarityInput;
MLInput remoteInferenceInput;

@Before
public void setUp() {
function = new AudioEmbeddingPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
}

@Test
public void process_NullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Preprocess function input can't be null");
function.apply(null);
}

@Test
public void process_WrongInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
function.apply(textSimilarityInput);
}

@Test
public void process_CorrectInput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(1, dataSet.getParameters().size());
assertEquals("hello", dataSet.getParameters().get("inputAudio"));
}

@Test
public void process_EmptyDocs() {
TextDocsInputDataSet mockDataSet = mock(TextDocsInputDataSet.class);
when(mockDataSet.getDocs()).thenReturn(Collections.emptyList());
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(mockDataSet).build();

exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No input audio provided");
function.apply(mlInput);
}

@Test
public void process_RemoteInferenceInput() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}
}
Loading
Loading