diff --git a/build.gradle b/build.gradle index 3b52deafb..f8671c3d6 100644 --- a/build.gradle +++ b/build.gradle @@ -59,6 +59,7 @@ opensearchplugin { classname "${projectPath}.${pathToPlugin}.${pluginClassName}" licenseFile rootProject.file('LICENSE') noticeFile rootProject.file('NOTICE') + extendedPlugins = ['opensearch-knn'] } dependencyLicenses.enabled = false @@ -134,7 +135,7 @@ def knnJarDirectory = "$buildDir/dependencies/opensearch-knn" dependencies { api "org.opensearch:opensearch:${opensearch_version}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" - api fileTree(dir: knnJarDirectory, include: '*.jar') + compileOnly fileTree(dir: knnJarDirectory, include: '*.jar') api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" } diff --git a/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java new file mode 100644 index 000000000..2198f7e9e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.common; + +import java.util.List; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +/** + * Utility class for working with vectors + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class VectorUtil { + + /** + * Converts a vector represented as a list to an array + * + * @param vectorAsList {@link List} of {@link Float}'s representing the vector + * @return array of floats produced from input list + */ + public static float[] vectorAsListToArray(List vectorAsList) { + float[] vector = new float[vectorAsList.size()]; + for (int i = 0; i < vectorAsList.size(); i++) { + vector[i] = vectorAsList.get(i); + } + return vector; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java 
b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 5a6c4775e..8cd525a5e 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -35,6 +35,33 @@ public class MLCommonsClientAccessor { private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); private final MachineLearningNodeClient mlClient; + /** + * Wrapper around {@link #inferenceSentences} that expects a single input text and produces a single floating + * point vector as a response. The listener is failed with an {@link IllegalStateException} unless exactly one vector comes back. + * + * @param modelId {@link String} model id forwarded to {@link #inferenceSentences} + * @param inputText {@link String} on which inference needs to happen + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + */ + public void inferenceSentence( + @NonNull final String modelId, + @NonNull final String inputText, + @NonNull final ActionListener> listener + ) { + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { + if (response.size() != 1) { + listener.onFailure( + new IllegalStateException( + "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]" + ) + ); + return; + } + + listener.onResponse(response.get(0)); + }, listener::onFailure)); + } + /** * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the * custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. 
The return will be sent diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 6a155227c..4b0a92b33 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -53,6 +53,7 @@ public Collection createComponents( ) { final MachineLearningNodeClient machineLearningNodeClient = new MachineLearningNodeClient(client); final MLCommonsClientAccessor clientAccessor = new MLCommonsClientAccessor(machineLearningNodeClient); + NeuralQueryBuilder.initialize(clientAccessor); return List.of(clientAccessor); } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java index 3b0bfd6a9..f283cf2f5 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java @@ -5,8 +5,13 @@ package org.opensearch.neuralsearch.plugin.query; +import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; + import java.io.IOException; +import java.util.function.Supplier; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -16,6 +21,8 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; +import org.opensearch.action.ActionListener; import org.opensearch.common.ParseField; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; @@ -24,7 +31,11 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.index.mapper.NumberFieldMapper; import 
org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; @@ -39,6 +50,7 @@ @Setter @Accessors(chain = true, fluent = true) @NoArgsConstructor +@AllArgsConstructor public class NeuralQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "neural"; @@ -54,10 +66,20 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder private static final int DEFAULT_K = 10; + private static MLCommonsClientAccessor ML_CLIENT; + + public static void initialize(MLCommonsClientAccessor mlClient) { + NeuralQueryBuilder.ML_CLIENT = mlClient; + } + private String fieldName; private String queryText; private String modelId; private int k = DEFAULT_K; + @VisibleForTesting + @Getter(AccessLevel.PACKAGE) + @Setter(AccessLevel.PACKAGE) + private Supplier vectorSupplier; /** * Constructor from stream input @@ -113,15 +135,14 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws */ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); - if (parser.nextToken() != XContentParser.Token.START_OBJECT) { + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT"); } parser.nextToken(); neuralQueryBuilder.fieldName(parser.currentName()); parser.nextToken(); parseQueryParams(parser, neuralQueryBuilder); - parser.nextToken(); - if (parser.currentToken() != XContentParser.Token.END_OBJECT) { + if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( parser.getTokenLocation(), "[" @@ -172,10 +193,35 @@ 
private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } } + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { + // When re-writing a QueryBuilder, if the QueryBuilder is not changed, doRewrite should return itself + // (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/QueryBuilder.java#L90-L98). + // Otherwise, it should return the modified copy (see rewrite logic + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117). + // With the asynchronous call, on first rewrite, we create a new + // vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to + // create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just + // return the current unmodified query builder. + if (vectorSupplier() != null) { + return vectorSupplier().get() == null ? 
this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k()); + } + + SetOnce vectorSetOnce = new SetOnce<>(); + queryRewriteContext.registerAsyncAction( + ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { + vectorSetOnce.set(vectorAsListToArray(floatList)); + actionListener.onResponse(null); + }, actionListener::onFailure))) + ); + return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get); + } + @Override protected Query doToQuery(QueryShardContext queryShardContext) { - // TODO Implement logic to build KNNQuery in this method - return null; + // All queries should be generated by the k-NN Query Builder + throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly"); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java new file mode 100644 index 000000000..1eba51e4f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.common; + +import java.util.Collections; +import java.util.List; + +import org.opensearch.test.OpenSearchTestCase; + +public class VectorUtilTests extends OpenSearchTestCase { + + public void testVectorAsListToArray() { + List vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f); + float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements); + + assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length); + for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) { + assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f); + } + + List vectorAsList_withNoElements = Collections.emptyList(); + 
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements); + assertEquals(0, vectorAsArray_withNoElements.length); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 0b553175b..bfcc4eb86 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -32,6 +32,9 @@ public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @Mock private ActionListener>> resultListener; + @Mock + private ActionListener> singleSentenceResultListener; + @Mock private MachineLearningNodeClient client; @@ -43,6 +46,22 @@ public void setup() { MockitoAnnotations.openMocks(this); } + public void testInferenceSentence_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + public void testInferenceSentences_whenValidInputThenSuccess() { final List> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); diff --git 
a/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java index eac7020bc..b4ff4a1e8 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilderTests.java @@ -5,6 +5,9 @@ package org.opensearch.neuralsearch.plugin.query; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.neuralsearch.plugin.TestUtils.xContentBuilderToMap; @@ -14,16 +17,29 @@ import static org.opensearch.neuralsearch.plugin.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Supplier; import lombok.SneakyThrows; +import org.opensearch.action.ActionListener; +import org.opensearch.client.Client; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.common.VectorUtil; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.test.OpenSearchTestCase; public class NeuralQueryBuilderTests extends 
OpenSearchTestCase { @@ -34,6 +50,7 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final int K = 10; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; + private static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[10]; @SneakyThrows public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { @@ -56,6 +73,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { .endObject(); XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); @@ -89,6 +107,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .endObject(); XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser); assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); @@ -126,6 +145,7 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { .endObject(); XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } @@ -141,6 +161,7 @@ public void testFromXContent_whenBuildWithMissingParameters_thenFail() { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(FIELD_NAME).endObject().endObject(); XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); expectThrows(IllegalArgumentException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } @@ -171,6 +192,7 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { .endObject(); XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); 
expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } @@ -332,4 +354,61 @@ public void testHashAndEquals() { assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName); assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode()); } + + @SneakyThrows + public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(expectedVector); + return null; + }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); + NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set vector supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.vectorSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + } + + public void testRewrite_whenVectorNull_thenReturnCopy() { + Supplier nullSupplier = () -> null; + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) 
+ .modelId(MODEL_ID) + .k(K) + .vectorSupplier(nullSupplier); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertEquals(neuralQueryBuilder, queryBuilder); + } + + public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertTrue(queryBuilder instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; + assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); + assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); + } }