diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 46acf7ac2..720c31a43 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -20,7 +20,6 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.model.MLModelTaskType; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.output.model.ModelTensor; @@ -65,7 +64,7 @@ public void inferenceSentence( /** * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the - * custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent + * custom model provided as modelId. The return will be sent * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of * inputText. We are not making this function generic enough to take any function or TaskType as currently we * need to run only TextEmbedding tasks only. @@ -84,7 +83,7 @@ public void inferenceSentences( /** * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent + * custom model provided as modelId. The return will be sent * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of * inputText. We are not making this function generic enough to take any function or TaskType as currently we * need to run only TextEmbedding tasks only. @@ -111,7 +110,7 @@ public void inferenceSentences( private MLInput createMLInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } private List> buildVectorFromResponse(MLOutput mlOutput) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index bfcc4eb86..81523d8d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -19,8 +19,8 @@ import org.opensearch.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.model.MLResultDataType; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors;