Integrate model inference to build query (#20)
Integrates ml-commons model inference capabilities to transform
NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing
logic and build.gradle to fix bugs. Minor enhancement to
MLCommonsClientAccessor to add single sentence inference.

Signed-off-by: John Mazanec <jmazane@amazon.com>
jmazanec15 authored Oct 19, 2022
1 parent 393141d commit 272d803
Showing 8 changed files with 240 additions and 6 deletions.
3 changes: 2 additions & 1 deletion build.gradle
@@ -59,6 +59,7 @@ opensearchplugin {
classname "${projectPath}.${pathToPlugin}.${pluginClassName}"
licenseFile rootProject.file('LICENSE')
noticeFile rootProject.file('NOTICE')
extendedPlugins = ['opensearch-knn']
}

dependencyLicenses.enabled = false
@@ -134,7 +135,7 @@ def knnJarDirectory = "$buildDir/dependencies/opensearch-knn"
dependencies {
api "org.opensearch:opensearch:${opensearch_version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
api fileTree(dir: knnJarDirectory, include: '*.jar')
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
}

32 changes: 32 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.common;

import java.util.List;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;

/**
* Utility class for working with vectors
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class VectorUtil {

/**
* Converts a vector represented as a list to an array
*
* @param vectorAsList {@link List} of {@link Float}'s representing the vector
* @return array of floats produced from input list
*/
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
float[] vector = new float[vectorAsList.size()];
for (int i = 0; i < vectorAsList.size(); i++) {
vector[i] = vectorAsList.get(i);
}
return vector;
}
}
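
For reference, a minimal usage sketch of the new utility: it adapts an ml-commons embedding (a List of Float) into the float[] form that KNNQueryBuilder expects. The values below are illustrative only.

import java.util.List;

import org.opensearch.neuralsearch.common.VectorUtil;

public class VectorUtilUsageSketch {
    public static void main(String[] args) {
        // An embedding as returned by model inference (illustrative values).
        List<Float> embedding = List.of(0.12f, -0.03f, 0.88f);

        // Convert to the primitive array form consumed by the k-NN query builder.
        float[] vector = VectorUtil.vectorAsListToArray(embedding);

        System.out.println(vector.length); // prints 3
    }
}
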
src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -35,6 +35,33 @@ public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;

/**
* Wrapper around {@link #inferenceSentences} that expects a single input text and produces a single floating
* point vector as a response.
*
* @param modelId {@link String} id of the model used to compute the embedding
* @param inputText {@link String} on which inference needs to happen
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
*/
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
) {
inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
if (response.size() != 1) {
listener.onFailure(
new IllegalStateException(
"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
)
);
return;
}

listener.onResponse(response.get(0));
}, listener::onFailure));
}

/**
* Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the
* custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent
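
A hedged sketch of calling the new single-sentence helper from plugin code; the model id, query text, and the handleVector callback are assumptions for illustration, not part of this change.

import java.util.List;

import org.opensearch.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

public class InferenceSentenceSketch {

    // clientAccessor is assumed to be the instance created in the plugin's createComponents.
    static void embedQueryText(MLCommonsClientAccessor clientAccessor) {
        clientAccessor.inferenceSentence(
            "my-text-embedding-model-id",      // assumed model id
            "what is a neural query?",         // the single sentence to embed
            ActionListener.wrap(
                InferenceSentenceSketch::handleVector,   // called with one List<Float> embedding
                e -> { /* handle inference failure */ }
            )
        );
    }

    static void handleVector(List<Float> vector) {
        // e.g., VectorUtil.vectorAsListToArray(vector) to build a k-NN query
    }
}
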
src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -53,6 +53,7 @@ public Collection<Object> createComponents(
) {
final MachineLearningNodeClient machineLearningNodeClient = new MachineLearningNodeClient(client);
final MLCommonsClientAccessor clientAccessor = new MLCommonsClientAccessor(machineLearningNodeClient);
NeuralQueryBuilder.initialize(clientAccessor);
return List.of(clientAccessor);
}

src/main/java/org/opensearch/neuralsearch/plugin/query/NeuralQueryBuilder.java
@@ -5,8 +5,13 @@

package org.opensearch.neuralsearch.plugin.query;

import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
import java.util.function.Supplier;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
@@ -16,6 +21,8 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.opensearch.action.ActionListener;
import org.opensearch.common.ParseField;
import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput;
@@ -24,7 +31,11 @@
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import com.google.common.annotations.VisibleForTesting;

@@ -39,6 +50,7 @@
@Setter
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> {

public static final String NAME = "neural";
@@ -54,10 +66,20 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>

private static final int DEFAULT_K = 10;

private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralQueryBuilder.ML_CLIENT = mlClient;
}

private String fieldName;
private String queryText;
private String modelId;
private int k = DEFAULT_K;
@VisibleForTesting
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;

/**
* Constructor from stream input
@@ -113,15 +135,14 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
*/
public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT");
}
parser.nextToken();
neuralQueryBuilder.fieldName(parser.currentName());
parser.nextToken();
parseQueryParams(parser, neuralQueryBuilder);
parser.nextToken();
if (parser.currentToken() != XContentParser.Token.END_OBJECT) {
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"["
@@ -172,10 +193,35 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
}
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// When re-writing a QueryBuilder, if the QueryBuilder is not changed, doRewrite should return itself
// (see
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/QueryBuilder.java#L90-L98).
// Otherwise, it should return the modified copy (see rewrite logic
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117.
// With the asynchronous call, on first rewrite, we create a new
// vector supplier that will get populated once the asynchronous call finishes and pass this supplier in to
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
if (vectorSupplier() != null) {
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k());
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get);
}

@Override
protected Query doToQuery(QueryShardContext queryShardContext) {
// TODO Implement logic to build KNNQuery in this method
return null;
// All queries should be generated by the k-NN Query Builder
throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
}

@Override
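
To illustrate the two-pass rewrite this change introduces, here is a hedged sketch. The field name, query text, and model id are made-up values, a QueryRewriteContext is assumed to be supplied by the search request's rewrite phase, and NeuralQueryBuilder.initialize is assumed to have already run in createComponents.

import java.io.IOException;

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.neuralsearch.plugin.query.NeuralQueryBuilder;

public class NeuralQueryRewriteSketch {

    static QueryBuilder rewriteOnce(QueryRewriteContext context) throws IOException {
        NeuralQueryBuilder neuralQuery = new NeuralQueryBuilder()
            .fieldName("passage_embedding")        // assumed knn_vector field
            .queryText("what is a neural query?")  // text to embed at query time
            .modelId("my-text-embedding-model-id") // assumed model id
            .k(10);

        // First pass: the vector supplier is null, so doRewrite registers an async
        // inference call and returns a copy holding a SetOnce-backed Supplier<float[]>.
        QueryBuilder firstPass = neuralQuery.rewrite(context);

        // After the registered async action populates the supplier, a subsequent
        // rewrite pass returns a KNNQueryBuilder built from the inferred vector.
        return firstPass;
    }
}
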
src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.common;

import java.util.Collections;
import java.util.List;

import org.opensearch.test.OpenSearchTestCase;

public class VectorUtilTests extends OpenSearchTestCase {

public void testVectorAsListToArray() {
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f);
}

List<Float> vectorAsList_withNoElements = Collections.emptyList();
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
assertEquals(0, vectorAsArray_withNoElements.length);
}

}
src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
@@ -32,6 +32,9 @@ public class MLCommonsClientAccessorTests extends OpenSearchTestCase {
@Mock
private ActionListener<List<List<Float>>> resultListener;

@Mock
private ActionListener<List<Float>> singleSentenceResultListener;

@Mock
private MachineLearningNodeClient client;

@@ -43,6 +46,22 @@ public void setup() {
MockitoAnnotations.openMocks(this);
}

public void testInferenceSentence_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));