From 365e0a74b34d3f01781a423b0e48a55e6cdc640e Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 11 Oct 2022 10:05:57 -0700 Subject: [PATCH] Address comments Signed-off-by: John Mazanec --- build.gradle | 1 + .../ml/MLCommonsClientAccessor.java | 2 +- .../plugin/query/NeuralQueryBuilder.java | 72 +++--------- .../plugin/query/NeuralQueryBuilderTests.java | 109 ++++++++++-------- 4 files changed, 80 insertions(+), 104 deletions(-) diff --git a/build.gradle b/build.gradle index 6c29f7cc2..3b52deafb 100644 --- a/build.gradle +++ b/build.gradle @@ -119,6 +119,7 @@ allprojects { } repositories { + mavenLocal() maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } mavenCentral() maven { url "https://plugins.gradle.org/m2/" } diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 7efb432ec..5a6c4775e 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -74,7 +74,7 @@ public void inferenceSentences( ) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - final MLInput mlInput = new MLInput(FunctionName.CUSTOM, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING); + final MLInput mlInput = new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING); final List> vector = new ArrayList<>(); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java index baecd984b..91f2bd76f 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java @@ -9,6 +9,8 @@ import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.builder.EqualsBuilder; @@ -24,6 +26,8 @@ import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryShardContext; +import com.google.common.annotations.VisibleForTesting; + /** * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a * k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as @@ -32,65 +36,28 @@ @Log4j2 @Getter +@Setter +@Accessors(chain = true, fluent = true) @NoArgsConstructor public class NeuralQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "neural"; + @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); + @VisibleForTesting static final ParseField K_FIELD = new ParseField("k"); + private static int DEFAULT_K = 10; + private String fieldName; private String queryText; private String modelId; - private int k; - - /** - * Set the fieldName this query will be executed against - * - * @param fieldName name of k-NN vector field that query will be executed against - * @return this - */ - public NeuralQueryBuilder fieldName(String fieldName) { - this.fieldName = fieldName; - return this; - } - - /** - * Set the queryText that will be translated into the dense query vector used for k-NN search. - * - * @param queryText Text of a query that should be translated to a dense vector - * @return this - */ - public NeuralQueryBuilder queryText(String queryText) { - this.queryText = queryText; - return this; - } - - /** - * Set the modelId that should produce the dense query vector - * - * @param modelId ID of model to produce query vector - * @return this - */ - public NeuralQueryBuilder modelId(String modelId) { - this.modelId = modelId; - return this; - } - - /** - * Set the number of neighbors that should be retrieved during k-NN search - * - * @param k number of neighbors to be retrieved in k-NN query - * @return this - */ - public NeuralQueryBuilder k(int k) { - this.k = k; - return this; - } + private int k = DEFAULT_K; /** * Constructor from stream input @@ -103,7 +70,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { this.fieldName = in.readString(); this.queryText = in.readString(); this.modelId = in.readString(); - this.k = in.readInt(); + this.k = in.readVInt(); } @Override @@ -111,7 +78,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(this.fieldName); out.writeString(this.queryText); out.writeString(this.modelId); - out.writeInt(this.k); + out.writeVInt(this.k); } @Override @@ -157,11 +124,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx if (parser.currentToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException(parser.getTokenLocation(), "Token must be END_OBJECT"); } - - requireValue(neuralQueryBuilder.getQueryText(), "Query text must be provided for neural query"); - requireValue(neuralQueryBuilder.getFieldName(), "Field name must be provided for neural query"); - requireValue(neuralQueryBuilder.getModelId(), "Model ID must be provided for neural query"); - requireValue(neuralQueryBuilder.getK(), "K must be provided for neural query"); + requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); + requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query"); + requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); return neuralQueryBuilder; } @@ -199,7 +164,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } @Override - protected Query doToQuery(QueryShardContext queryShardContext) throws IOException { + protected Query doToQuery(QueryShardContext queryShardContext) { + // TODO Implement logic to build KNNQuery in this method return null; } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java index 96fc7050b..eac7020bc 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java @@ -16,6 +16,8 @@ import java.io.IOException; import java.util.Map; +import lombok.SneakyThrows; + import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.ToXContent; @@ -33,7 +35,8 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; - public void testFromXContent_valid_withDefaults() throws IOException { + @SneakyThrows + public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { /* { "VECTOR_FIELD": { @@ -55,13 +58,14 @@ public void testFromXContent_valid_withDefaults() throws IOException { XContentParser contentParser = createParser(xContentBuilder); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); - assertEquals(FIELD_NAME, neuralQueryBuilder.getFieldName()); - assertEquals(QUERY_TEXT, neuralQueryBuilder.getQueryText()); - assertEquals(MODEL_ID, neuralQueryBuilder.getModelId()); - assertEquals(K, neuralQueryBuilder.getK()); + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(K, neuralQueryBuilder.k()); } - public void testFromXContent_valid_withOptionals() throws IOException { + @SneakyThrows + public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { /* { "VECTOR_FIELD": { @@ -87,15 +91,16 @@ public void testFromXContent_valid_withOptionals() throws IOException { XContentParser contentParser = createParser(xContentBuilder); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); - assertEquals(FIELD_NAME, neuralQueryBuilder.getFieldName()); - assertEquals(QUERY_TEXT, neuralQueryBuilder.getQueryText()); - assertEquals(MODEL_ID, neuralQueryBuilder.getModelId()); - assertEquals(K, neuralQueryBuilder.getK()); + assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); + assertEquals(K, neuralQueryBuilder.k()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0); assertEquals(QUERY_NAME, neuralQueryBuilder.queryName()); } - public void testFromXContent_invalid_multipleRootFields() throws IOException { + @SneakyThrows + public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { /* { "VECTOR_FIELD": { @@ -124,7 +129,8 @@ public void testFromXContent_invalid_multipleRootFields() throws IOException { expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } - public void testFromXContent_invalid_missingParameters() throws IOException { + @SneakyThrows + public void testFromXContent_whenBuildWithMissingParameters_thenFail() { /* { "VECTOR_FIELD": { @@ -138,7 +144,8 @@ public void testFromXContent_invalid_missingParameters() throws IOException { expectThrows(IllegalArgumentException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } - public void testFromXContent_invalid_duplicateParameters() throws IOException { + @SneakyThrows + public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { /* { "VECTOR_FIELD": { @@ -168,7 +175,8 @@ public void testFromXContent_invalid_duplicateParameters() throws IOException { } @SuppressWarnings("unchecked") - public void testToXContent() throws IOException { + @SneakyThrows + public void testToXContent() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).modelId(MODEL_ID).queryText(QUERY_TEXT).k(K); XContentBuilder builder = XContentFactory.jsonBuilder(); @@ -197,7 +205,8 @@ public void testToXContent() throws IOException { assertEquals(K, secondInnerMap.get(K_FIELD.getPreferredName())); } - public void testStreams() throws IOException { + @SneakyThrows + public void testStreams() { NeuralQueryBuilder original = new NeuralQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); @@ -227,100 +236,100 @@ public void testHashAndEquals() { int k1 = 1; int k2 = 2; - NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().fieldName(fieldName1) + NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 - NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline + NeuralQueryBuilder neuralQueryBuilder_baselineCopy = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except default boost and query name - NeuralQueryBuilder neuralQueryBuilder3 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except default boost and query name + NeuralQueryBuilder neuralQueryBuilder_defaultBoostAndQueryName = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1); - // Identical to neuralQueryBuilder1 except diff field name - NeuralQueryBuilder neuralQueryBuilder4 = new NeuralQueryBuilder().fieldName(fieldName2) + // Identical to neuralQueryBuilder_baseline except diff field name + NeuralQueryBuilder neuralQueryBuilder_diffFieldName = new NeuralQueryBuilder().fieldName(fieldName2) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except diff query text - NeuralQueryBuilder neuralQueryBuilder5 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except diff query text + NeuralQueryBuilder neuralQueryBuilder_diffQueryText = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText2) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except diff model ID - NeuralQueryBuilder neuralQueryBuilder6 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except diff model ID + NeuralQueryBuilder neuralQueryBuilder_diffModelId = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId2) .k(k1) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except diff k - NeuralQueryBuilder neuralQueryBuilder7 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except diff k + NeuralQueryBuilder neuralQueryBuilder_diffK = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k2) .boost(boost1) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except diff boost - NeuralQueryBuilder neuralQueryBuilder8 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except diff boost + NeuralQueryBuilder neuralQueryBuilder_diffBoost = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost2) .queryName(queryName1); - // Identical to neuralQueryBuilder1 except diff query name - NeuralQueryBuilder neuralQueryBuilder9 = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to neuralQueryBuilder_baseline except diff query name + NeuralQueryBuilder neuralQueryBuilder_diffQueryName = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) .k(k1) .boost(boost1) .queryName(queryName2); - assertEquals(neuralQueryBuilder1, neuralQueryBuilder1); - assertEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder1.hashCode()); + assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baseline); + assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baseline.hashCode()); - assertEquals(neuralQueryBuilder1, neuralQueryBuilder2); - assertEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder2.hashCode()); + assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baselineCopy); + assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baselineCopy.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder3); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder3.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_defaultBoostAndQueryName); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_defaultBoostAndQueryName.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder4); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder4.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFieldName); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFieldName.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder5); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder5.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryText); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryText.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder6); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder6.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffModelId); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffModelId.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder7); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder7.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffK); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffK.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder8); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder8.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffBoost); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffBoost.hashCode()); - assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder9); - assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder9.hashCode()); + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode()); } }