Skip to content

Commit

Permalink
Add integration tests for neural query (#36)
Browse files Browse the repository at this point in the history
Adds a series of integration tests for neural query type. Adds shared
functionality to base class as well as a utility class. Increase test 
cluster heap to 1 GB.

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored Oct 31, 2022
1 parent 9c2739c commit e30285b
Show file tree
Hide file tree
Showing 7 changed files with 656 additions and 21 deletions.
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ testClusters.integTest {
debugPort += 1
}
}

// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")
}

// Remote Integration Tests
Expand Down
263 changes: 253 additions & 10 deletions src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@
package org.opensearch.neuralsearch.common;

import static org.apache.http.entity.ContentType.APPLICATION_JSON;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.SneakyThrows;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.Header;
Expand All @@ -28,14 +35,20 @@
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase;
import org.opensearch.rest.RestStatus;

import com.google.common.collect.ImmutableList;

public abstract class BaseNeuralSearchIT extends OpenSearchRestTestCase {
public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final Locale LOCALE = Locale.ROOT;

Expand All @@ -45,7 +58,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchRestTestCase {

protected final ClassLoader classLoader = this.getClass().getClassLoader();

public String uploadModel(String requestBody) throws Exception {
protected String uploadModel(String requestBody) throws Exception {
Response uploadResponse = makeRequest(
client(),
"POST",
Expand Down Expand Up @@ -74,7 +87,7 @@ public String uploadModel(String requestBody) throws Exception {
return modelId;
}

public void loadModel(String modelId) throws IOException, InterruptedException {
protected void loadModel(String modelId) throws IOException, InterruptedException {
Response uploadResponse = makeRequest(
client(),
"POST",
Expand All @@ -100,6 +113,56 @@ public void loadModel(String modelId) throws IOException, InterruptedException {
}
}

/**
* Upload default model and load into the cluster
*
* @return modelID
*/
@SneakyThrows
protected String prepareModel() {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
String modelId = uploadModel(requestBody);
loadModel(modelId);
return modelId;
}

/**
* Execute model inference on the provided query text
*
* @param modelId id of model to run inference
* @param queryText text to be transformed to a model
* @return text embedding
*/
@SuppressWarnings("unchecked")
@SneakyThrows
protected float[] runInference(String modelId, String queryText) {
Response inferenceResponse = makeRequest(
client(),
"POST",
String.format(LOCALE, "/_plugins/_ml/_predict/text_embedding/%s", modelId),
null,
toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"],\"target_response\": [\"sentence_embedding\"]}", queryText)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
);

Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
EntityUtils.toString(inferenceResponse.getEntity()),
false
);

Object inference_results = inferenceResJson.get("inference_results");
assertTrue(inference_results instanceof List);
List<Object> inferenceResultsAsMap = (List<Object>) inference_results;
assertEquals(1, inferenceResultsAsMap.size());
Map<String, Object> result = (Map<String, Object>) inferenceResultsAsMap.get(0);
List<Object> output = (List<Object>) result.get("output");
assertEquals(1, output.size());
Map<String, Object> map = (Map<String, Object>) output.get(0);
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
return vectorAsListToArray(data);
}

protected void createIndexWithConfiguration(String indexName, String indexConfiguration, String pipelineName) throws Exception {
if (StringUtils.isNotBlank(pipelineName)) {
indexConfiguration = String.format(LOCALE, indexConfiguration, pipelineName);
Expand All @@ -121,7 +184,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
assertEquals(indexName, node.get("index").toString());
}

public void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
Response pipelineCreateResponse = makeRequest(
client(),
"PUT",
Expand All @@ -144,7 +207,155 @@ public void createPipelineProcessor(String modelId, String pipelineName) throws
assertEquals("true", node.get("acknowledged").toString());
}

public Map<String, Object> getTaskQueryResponse(String taskId) throws IOException {
/**
* Get the number of documents in a particular index
*
* @param indexName name of index
* @return number of documents indexed to that index
*/
@SneakyThrows
protected int getDocCount(String indexName) {
Request request = new Request("GET", "/" + indexName + "/_count");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String responseBody = EntityUtils.toString(response.getEntity());
Map<String, Object> responseMap = createParser(XContentType.JSON.xContent(), responseBody).map();
return (Integer) responseMap.get("count");
}

/**
* Execute a search request initialized from a neural query builder
*
* @param index Index to search against
* @param queryBuilder queryBuilder to produce source of query
* @param resultSize number of results to return in the search
* @return Search results represented as a map
*/
protected Map<String, Object> search(String index, QueryBuilder queryBuilder, int resultSize) {
return search(index, queryBuilder, null, resultSize);
}

/**
* Execute a search request initialized from a neural query builder that can add a rescore query to the request
*
* @param index Index to search against
* @param queryBuilder queryBuilder to produce source of query
* @param rescorer used for rescorer query builder
* @param resultSize number of results to return in the search
* @return Search results represented as a map
*/
@SneakyThrows
protected Map<String, Object> search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query");
queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);

if (rescorer != null) {
builder.startObject("rescore").startObject("query").field("query_weight", 0.0f).field("rescore_query");
rescorer.toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject().endObject();
}

builder.endObject();

Request request = new Request("POST", "/" + index + "/_search");
request.addParameter("size", Integer.toString(resultSize));
request.setJsonEntity(Strings.toString(builder));

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(response.getEntity());

return XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), responseBody, false);
}

/**
* Add a set of knn vectors
*
* @param index Name of the index
* @param docId ID of document to be added
* @param vectorFieldNames List of vectir fields to be added
* @param vectors List of vectors corresponding to those fields
*/
protected void addKnnDoc(String index, String docId, List<String> vectorFieldNames, List<Object[]> vectors) {
addKnnDoc(index, docId, vectorFieldNames, vectors, Collections.emptyList(), Collections.emptyList());
}

/**
* Add a set of knn vectors and text to an index
*
* @param index Name of the index
* @param docId ID of document to be added
* @param vectorFieldNames List of vectir fields to be added
* @param vectors List of vectors corresponding to those fields
* @param textFieldNames List of text fields to be added
* @param texts List of text corresponding to those fields
*/
@SneakyThrows
protected void addKnnDoc(
String index,
String docId,
List<String> vectorFieldNames,
List<Object[]> vectors,
List<String> textFieldNames,
List<String> texts
) {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
for (int i = 0; i < vectorFieldNames.size(); i++) {
builder.field(vectorFieldNames.get(i), vectors.get(i));
}

for (int i = 0; i < textFieldNames.size(); i++) {
builder.field(textFieldNames.get(i), texts.get(i));
}
builder.endObject();

request.setJsonEntity(Strings.toString(builder));
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

/**
* Parse the first returned hit from a search response as a map
*
* @param searchResponseAsMap Complete search response as a map
* @return Map of first internal hit from the search
*/
@SuppressWarnings("unchecked")
protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
List<Object> hits2List = (List<Object>) hits1map.get("hits");
assertTrue(hits2List.size() > 0);
return (Map<String, Object>) hits2List.get(0);
}

/**
* Create a k-NN index from a list of KNNFieldConfigs
*
* @param indexName of index to be created
* @param knnFieldConfigs list of configs specifying field
*/
@SneakyThrows
protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs) {
createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs), "");
}

/**
* Computes the expected distance between an indexVector and query text without using the neural query type.
*
* @param modelId ID of model to run inference
* @param indexVector vector to compute score against
* @param spaceType Space to measure distance
* @param queryText Text to produce query vector from
* @return Expected OpenSearch score for this indexVector
*/
protected float computeExpectedScore(String modelId, float[] indexVector, SpaceType spaceType, String queryText) {
float[] queryVector = runInference(modelId, queryText);
return spaceType.getVectorSimilarityFunction().compare(queryVector, indexVector);
}

protected Map<String, Object> getTaskQueryResponse(String taskId) throws IOException {
Response taskQueryResponse = makeRequest(
client(),
"GET",
Expand All @@ -160,12 +371,37 @@ public Map<String, Object> getTaskQueryResponse(String taskId) throws IOExceptio
);
}

public boolean checkComplete(Map<String, Object> node) {
protected boolean checkComplete(Map<String, Object> node) {
Predicate<Map<String, Object>> predicate = x -> node.get("error") != null || "COMPLETED".equals(String.valueOf(node.get("state")));
return predicate.test(node);
}

public static Response makeRequest(
@SneakyThrows
private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs) {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("settings")
.field("number_of_shards", 3)
.field("index.knn", true)
.endObject()
.startObject("mappings")
.startObject("properties");

for (KNNFieldConfig knnFieldConfig : knnFieldConfigs) {
xContentBuilder.startObject(knnFieldConfig.getName())
.field("type", "knn_vector")
.field("dimension", Integer.toString(knnFieldConfig.getDimension()))
.startObject("method")
.field("engine", "lucene")
.field("space_type", knnFieldConfig.getSpaceType().getValue())
.field("name", "hnsw")
.endObject()
.endObject();
}
return Strings.toString(xContentBuilder.endObject().endObject().endObject());
}

protected static Response makeRequest(
RestClient client,
String method,
String endpoint,
Expand All @@ -176,7 +412,7 @@ public static Response makeRequest(
return makeRequest(client, method, endpoint, params, entity, headers, false);
}

public static Response makeRequest(
protected static Response makeRequest(
RestClient client,
String method,
String endpoint,
Expand All @@ -203,8 +439,15 @@ public static Response makeRequest(
return client.performRequest(request);
}

public static HttpEntity toHttpEntity(String jsonString) {
protected static HttpEntity toHttpEntity(String jsonString) {
return new StringEntity(jsonString, APPLICATION_JSON);
}

@AllArgsConstructor
@Getter
protected static class KNNFieldConfig {
private final String name;
private final Integer dimension;
private final SpaceType spaceType;
}
}
Loading

0 comments on commit e30285b

Please sign in to comment.