Skip to content

Commit

Permalink
Fixes #2317 predict api not working with asymmetric models (#2318)
Browse files Browse the repository at this point in the history
* Fixes #2317 predict api not working with asymmetric models

Signed-off-by: br3no <breno@veltefaria.de>

* Adding a unit test that covers the code path parsing the "parameters" field.

Signed-off-by: br3no <breno@veltefaria.de>

* Removing unintended import of Guava

Signed-off-by: br3no <breno@veltefaria.de>

* Refactor package of AsymmetricTextEmbeddingParameters

The MLCommonsClassLoader expects all MLAlgoParameters to be in the
"org.opensearch.ml.common.input.parameter" package.

Signed-off-by: br3no <breno@veltefaria.de>

* fixing unit test after package refactoring

Signed-off-by: br3no <breno@veltefaria.de>

---------

Signed-off-by: br3no <breno@veltefaria.de>
(cherry picked from commit 8425a65)
  • Loading branch information
br3no authored and github-actions[bot] committed Apr 25, 2024
1 parent 213ce30 commit cd9cc9e
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.input.nlp;

import java.util.Locale;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -13,6 +14,7 @@
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;

import java.io.IOException;
Expand Down Expand Up @@ -82,6 +84,7 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
List<String> docs = new ArrayList<>();
ModelResultFilter resultFilter = null;

MLAlgoParams mlParameters = null;
boolean returnBytes = false;
boolean returnNumber = true;
List<String> targetResponse = new ArrayList<>();
Expand All @@ -93,6 +96,10 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
parser.nextToken();

switch (fieldName) {
case ML_PARAMETERS_FIELD:
mlParameters = parser.namedObject(MLAlgoParams.class, this.algorithm.name().toUpperCase(
Locale.ROOT), null);
break;
case RETURN_BYTES_FIELD:
returnBytes = parser.booleanValue();
break;
Expand Down Expand Up @@ -137,6 +144,8 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws
throw new IllegalArgumentException("Empty text docs");
}
inputDataset = new TextDocsInputDataSet(docs, filter);

this.parameters = mlParameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.dataset;
package org.opensearch.ml.common.input.parameter.textembedding;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand Down Expand Up @@ -33,7 +33,7 @@
* `query_prefix` and `passage_prefix` configuration parameters.
*/
@Data
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
@MLAlgoParameter(algorithms={FunctionName.TEXT_EMBEDDING})
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {

public enum EmbeddingContentType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import java.io.IOException;
import java.util.function.Function;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException
// Takes the string form of params, replaces "QUERY" with "fu", and expects the parse to fail
// with an IllegalArgumentException whose message names the unknown enum constant FU.
// NOTE(review): the two expectMessage lines differ only in the package name and look like the
// removed/added pair of a diff hunk — confirm that only the second one exists in the real file.
@Test
public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU");
exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU");
String paramsStr = contentObjectToString(params);
// The replacement is lower-case "fu"; the expected message names it as FU.
testParseFromString(params, paramsStr.replace("QUERY","fu"), function);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.opensearch.ml.common.input.nlp;

import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -16,6 +18,7 @@
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.search.SearchModule;

import java.io.IOException;
Expand Down Expand Up @@ -65,10 +68,17 @@ public void parseTextDocsMLInput_NewWay() throws IOException {
parseMLInput(jsonStr, 2);
}

/**
 * Parses a text-docs request that, besides {@code text_docs} and {@code result_filter},
 * carries a {@code parameters} object with {@code content_type} set to {@code passage}.
 * The request holds two documents, which is the doc size handed to {@code parseMLInput}.
 */
@Test
public void parseTextDocsMLInput_WithParameters() throws IOException {
    final int expectedDocCount = 2;
    final String requestBody = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}, \"parameters\" : {\"content_type\": \"passage\"}}";
    parseMLInput(requestBody, expectedDocCount);
}

private void parseMLInput(String jsonStr, int docSize) throws IOException {
XContentParser parser = XContentType.JSON.xContent()
.createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
.createParser(new NamedXContentRegistry(Stream.concat(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents().stream(), Stream.of(
AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY)).collect(Collectors.toList())), null, jsonStr);
parser.nextToken();

MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
import org.opensearch.ml.cluster.MLCommonsClusterEventListener;
import org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
Expand All @@ -103,6 +102,7 @@
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.spi.MLCommonsExtension;
import org.opensearch.ml.common.spi.memory.Memory;
Expand Down

0 comments on commit cd9cc9e

Please sign in to comment.