From 13fd30084de717800920653b28c744de899cdf0e Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 10 Oct 2023 13:17:30 -0700 Subject: [PATCH] support multiple docs for remote embedding model Signed-off-by: Yaliang Wu --- .../ml/engine/algorithms/remote/ConnectorUtils.java | 2 +- .../algorithms/remote/RemoteConnectorExecutor.java | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index fb9cb249d7..cbfad6fd5f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat docs.add(null); } } - if (preProcessFunction.contains("${parameters")) { + if (preProcessFunction.contains("${parameters.")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); preProcessFunction = substitutor.replace(preProcessFunction); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 8712f771c7..176fdb428e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); + int processedDocs = 0; + while(processedDocs < textDocsInputDataSet.getDocs().size()) { + List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + List tempTensorOutputs = new ArrayList<>(); + preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); + processedDocs += Math.max(tempTensorOutputs.size(), 1); + tensorOutputs.addAll(tempTensorOutputs); + } + } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); }