Add text embedding processor to neural search #18

Merged · 5 commits · Oct 20, 2022
Changes from all commits
3 changes: 3 additions & 0 deletions build.gradle
@@ -5,7 +5,9 @@
* Learn more about Gradle by exploring our samples at https://docs.gradle.org/7.5.1/samples
* This project uses @Incubating APIs which are subject to change.
*/

+import org.opensearch.gradle.test.RestIntegTestTask

import java.util.concurrent.Callable

apply plugin: 'java'
@@ -137,6 +139,7 @@ dependencies {
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
+implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10'
}

// From maven, we can get the k-NN plugin as a zip. In order to add the jar to the classpath, we need to unzip the
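The new commons-lang3 dependency is presumably pulled in for null-safe utility helpers in the processor code; its exact call sites are not shown in this diff. A tiny illustrative sketch (not part of the PR) of the kind of validation it enables:

import org.apache.commons.lang3.StringUtils;

// Illustrative only: reject blank input before requesting inference.
static void validateInputText(String text) {
    if (StringUtils.isBlank(text)) { // null-safe blank check from commons-lang3
        throw new IllegalArgumentException("input text must not be null or blank");
    }
}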
src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -8,19 +8,22 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

+import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelTaskType;
+import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
@@ -99,23 +102,54 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
-final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
-final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
-final MLInput mlInput = new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
-final List<List<Float>> vector = new ArrayList<>();
-
+MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
-final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
-final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
-for (final ModelTensors tensors : tensorOutputList) {
-final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
-for (final ModelTensor tensor : tensorsList) {
-vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
-}
-}
+final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, listener::onFailure));
}

+/**
+ * Abstraction to call the predict API of the ML client with the default TARGET_RESPONSE_FILTERS. It uses the
+ * custom model provided as modelId and runs the {@link MLModelTaskType#TEXT_EMBEDDING} task. The result is
+ * returned directly as a {@link List} of {@link List} of {@link Float} in the order of inputText. We are not
+ * making this function generic enough to take any function or task type, as currently we only need to run
+ * TextEmbedding tasks. Please note that this is a blocking method; use it only when the caller must block
+ * waiting for the response, otherwise use {@link #inferenceSentences(String, List, ActionListener)} instead.
+ * @param modelId {@link String} id of the model on which inference needs to happen.
+ * @param inputText {@link List} of {@link String} on which inference needs to happen.
+ * @return {@link List} of {@link List} of {@link Float} representing the text embedding vectors.
+ * @throws ExecutionException If the underlying task failed, this exception is thrown by future.get().
+ * @throws InterruptedException If the waiting thread is interrupted.
+ */
+public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
+throws ExecutionException, InterruptedException {
+final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
+final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
+final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
+log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
+return vector;
+}
+
+private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
+final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
+final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
+return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
+}
+
+private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
+final List<List<Float>> vector = new ArrayList<>();
+final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
+final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
+for (final ModelTensors tensors : tensorOutputList) {
+final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
+for (final ModelTensor tensor : tensorsList) {
+vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
+}
+}
+return vector;
+}

}
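For reference, a minimal usage sketch of the two inferenceSentences variants above (assuming an already-constructed MLCommonsClientAccessor named accessor, a deployed text-embedding model id, and the imports shown in the file; illustrative only, not part of the PR):

// Async variant: the embedding vectors are delivered to the listener on completion.
accessor.inferenceSentences("my-model-id", List.of("hello world"), ActionListener.wrap(
    vectors -> log.debug("received {} embedding vectors", vectors.size()),
    e -> log.error("inference failed", e)
));

// Blocking variant: waits on the underlying ActionFuture via future.get().
try {
    List<List<Float>> vectors = accessor.inferenceSentences("my-model-id", List.of("hello world"));
} catch (ExecutionException | InterruptedException e) {
    // ExecutionException wraps a failed task; InterruptedException if the wait is interrupted.
}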
src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -8,6 +8,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
+import java.util.Map;
import java.util.function.Supplier;

import org.opensearch.action.ActionRequest;
@@ -19,12 +20,16 @@
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
+import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.plugin.query.NeuralQueryBuilder;
+import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
+import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.transport.MLPredictAction;
import org.opensearch.neuralsearch.transport.MLPredictTransportAction;
import org.opensearch.plugins.ActionPlugin;
+import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
@@ -35,7 +40,9 @@
/**
* Neural Search plugin class
*/
-public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin {
+public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin {
+
+private MLCommonsClientAccessor clientAccessor;

@Override
public Collection<Object> createComponents(
@@ -51,8 +58,6 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
-final MachineLearningNodeClient machineLearningNodeClient = new MachineLearningNodeClient(client);
-final MLCommonsClientAccessor clientAccessor = new MLCommonsClientAccessor(machineLearningNodeClient);
NeuralQueryBuilder.initialize(clientAccessor);
return List.of(clientAccessor);
}
@@ -72,4 +77,10 @@ public List<QuerySpec<?>> getQueries() {
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent)
);
}

+@Override
+public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
+clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
+return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
+}
}
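To show how the new IngestPlugin hook is consumed, here is a rough sketch of what the ingest framework does with getProcessors. The "model_id" and "field_map" config keys, the method name, and the tag/description values are assumptions for illustration only; this diff does not show the processor's accepted configuration:

// Sketch: the ingest service gathers factories from every IngestPlugin at startup,
// then builds processors from pipeline definitions. Assumes java.util.HashMap and the
// imports shown in the file above.
Processor buildTextEmbeddingProcessor(NeuralSearch plugin, Processor.Parameters parameters) throws Exception {
    Map<String, Processor.Factory> factories = plugin.getProcessors(parameters);
    Map<String, Object> config = new HashMap<>();
    config.put("model_id", "<deployed-model-id>");          // hypothetical parameter name
    config.put("field_map", Map.of("text", "text_vector")); // hypothetical parameter name
    return factories.get(TextEmbeddingProcessor.TYPE)
        .create(factories, "my-tag", "my-description", config);
}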