Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 committed Oct 11, 2022
1 parent 54ecd48 commit 365e0a7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 104 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ allprojects {
}

repositories {
mavenLocal()
maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" }
mavenCentral()
maven { url "https://plugins.gradle.org/m2/" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void inferenceSentences(
) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
final MLInput mlInput = new MLInput(FunctionName.CUSTOM, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
final MLInput mlInput = new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset, MLModelTaskType.TEXT_EMBEDDING);
final List<List<Float>> vector = new ArrayList<>();

mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang.builder.EqualsBuilder;
Expand All @@ -24,6 +26,8 @@
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import com.google.common.annotations.VisibleForTesting;

/**
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a
* k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as
Expand All @@ -32,65 +36,28 @@

@Log4j2
@Getter
@Setter
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> {

public static final String NAME = "neural";

@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");

@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

@VisibleForTesting
static final ParseField K_FIELD = new ParseField("k");

private static int DEFAULT_K = 10;

private String fieldName;
private String queryText;
private String modelId;
private int k;

/**
* Set the fieldName this query will be executed against
*
* @param fieldName name of k-NN vector field that query will be executed against
* @return this
*/
public NeuralQueryBuilder fieldName(String fieldName) {
this.fieldName = fieldName;
return this;
}

/**
* Set the queryText that will be translated into the dense query vector used for k-NN search.
*
* @param queryText Text of a query that should be translated to a dense vector
* @return this
*/
public NeuralQueryBuilder queryText(String queryText) {
this.queryText = queryText;
return this;
}

/**
* Set the modelId that should produce the dense query vector
*
* @param modelId ID of model to produce query vector
* @return this
*/
public NeuralQueryBuilder modelId(String modelId) {
this.modelId = modelId;
return this;
}

/**
* Set the number of neighbors that should be retrieved during k-NN search
*
* @param k number of neighbors to be retrieved in k-NN query
* @return this
*/
public NeuralQueryBuilder k(int k) {
this.k = k;
return this;
}
private int k = DEFAULT_K;

/**
* Constructor from stream input
Expand All @@ -103,15 +70,15 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
this.k = in.readInt();
this.k = in.readVInt();
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeString(this.modelId);
out.writeInt(this.k);
out.writeVInt(this.k);
}

@Override
Expand Down Expand Up @@ -157,11 +124,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
if (parser.currentToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "Token must be END_OBJECT");
}

requireValue(neuralQueryBuilder.getQueryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.getFieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.getModelId(), "Model ID must be provided for neural query");
requireValue(neuralQueryBuilder.getK(), "K must be provided for neural query");
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

return neuralQueryBuilder;
}
Expand Down Expand Up @@ -199,7 +164,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
}

@Override
protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
protected Query doToQuery(QueryShardContext queryShardContext) {
// TODO Implement logic to build KNNQuery in this method
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.io.IOException;
import java.util.Map;

import lombok.SneakyThrows;

import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.ToXContent;
Expand All @@ -33,7 +35,8 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase {
private static final float BOOST = 1.8f;
private static final String QUERY_NAME = "queryName";

public void testFromXContent_valid_withDefaults() throws IOException {
@SneakyThrows
public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() {
/*
{
"VECTOR_FIELD": {
Expand All @@ -55,13 +58,14 @@ public void testFromXContent_valid_withDefaults() throws IOException {
XContentParser contentParser = createParser(xContentBuilder);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, neuralQueryBuilder.getFieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.getQueryText());
assertEquals(MODEL_ID, neuralQueryBuilder.getModelId());
assertEquals(K, neuralQueryBuilder.getK());
assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(K, neuralQueryBuilder.k());
}

public void testFromXContent_valid_withOptionals() throws IOException {
@SneakyThrows
public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
/*
{
"VECTOR_FIELD": {
Expand All @@ -87,15 +91,16 @@ public void testFromXContent_valid_withOptionals() throws IOException {
XContentParser contentParser = createParser(xContentBuilder);
NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, neuralQueryBuilder.getFieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.getQueryText());
assertEquals(MODEL_ID, neuralQueryBuilder.getModelId());
assertEquals(K, neuralQueryBuilder.getK());
assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText());
assertEquals(MODEL_ID, neuralQueryBuilder.modelId());
assertEquals(K, neuralQueryBuilder.k());
assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0);
assertEquals(QUERY_NAME, neuralQueryBuilder.queryName());
}

public void testFromXContent_invalid_multipleRootFields() throws IOException {
@SneakyThrows
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
/*
{
"VECTOR_FIELD": {
Expand Down Expand Up @@ -124,7 +129,8 @@ public void testFromXContent_invalid_multipleRootFields() throws IOException {
expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser));
}

public void testFromXContent_invalid_missingParameters() throws IOException {
@SneakyThrows
public void testFromXContent_whenBuildWithMissingParameters_thenFail() {
/*
{
"VECTOR_FIELD": {
Expand All @@ -138,7 +144,8 @@ public void testFromXContent_invalid_missingParameters() throws IOException {
expectThrows(IllegalArgumentException.class, () -> NeuralQueryBuilder.fromXContent(contentParser));
}

public void testFromXContent_invalid_duplicateParameters() throws IOException {
@SneakyThrows
public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
/*
{
"VECTOR_FIELD": {
Expand Down Expand Up @@ -168,7 +175,8 @@ public void testFromXContent_invalid_duplicateParameters() throws IOException {
}

@SuppressWarnings("unchecked")
public void testToXContent() throws IOException {
@SneakyThrows
public void testToXContent() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).modelId(MODEL_ID).queryText(QUERY_TEXT).k(K);

XContentBuilder builder = XContentFactory.jsonBuilder();
Expand Down Expand Up @@ -197,7 +205,8 @@ public void testToXContent() throws IOException {
assertEquals(K, secondInnerMap.get(K_FIELD.getPreferredName()));
}

public void testStreams() throws IOException {
@SneakyThrows
public void testStreams() {
NeuralQueryBuilder original = new NeuralQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
Expand Down Expand Up @@ -227,100 +236,100 @@ public void testHashAndEquals() {
int k1 = 1;
int k2 = 2;

NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().fieldName(fieldName1)
NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k1)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1
NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline
NeuralQueryBuilder neuralQueryBuilder_baselineCopy = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k1)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except default boost and query name
NeuralQueryBuilder neuralQueryBuilder3 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except default boost and query name
NeuralQueryBuilder neuralQueryBuilder_defaultBoostAndQueryName = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k1);

// Identical to neuralQueryBuilder1 except diff field name
NeuralQueryBuilder neuralQueryBuilder4 = new NeuralQueryBuilder().fieldName(fieldName2)
// Identical to neuralQueryBuilder_baseline except diff field name
NeuralQueryBuilder neuralQueryBuilder_diffFieldName = new NeuralQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
.k(k1)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except diff query text
NeuralQueryBuilder neuralQueryBuilder5 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except diff query text
NeuralQueryBuilder neuralQueryBuilder_diffQueryText = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
.k(k1)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except diff model ID
NeuralQueryBuilder neuralQueryBuilder6 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except diff model ID
NeuralQueryBuilder neuralQueryBuilder_diffModelId = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
.k(k1)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except diff k
NeuralQueryBuilder neuralQueryBuilder7 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except diff k
NeuralQueryBuilder neuralQueryBuilder_diffK = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k2)
.boost(boost1)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except diff boost
NeuralQueryBuilder neuralQueryBuilder8 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except diff boost
NeuralQueryBuilder neuralQueryBuilder_diffBoost = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k1)
.boost(boost2)
.queryName(queryName1);

// Identical to neuralQueryBuilder1 except diff query name
NeuralQueryBuilder neuralQueryBuilder9 = new NeuralQueryBuilder().fieldName(fieldName1)
// Identical to neuralQueryBuilder_baseline except diff query name
NeuralQueryBuilder neuralQueryBuilder_diffQueryName = new NeuralQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.k(k1)
.boost(boost1)
.queryName(queryName2);

assertEquals(neuralQueryBuilder1, neuralQueryBuilder1);
assertEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder1.hashCode());
assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baseline);
assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baseline.hashCode());

assertEquals(neuralQueryBuilder1, neuralQueryBuilder2);
assertEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder2.hashCode());
assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baselineCopy);
assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baselineCopy.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder3);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder3.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_defaultBoostAndQueryName);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_defaultBoostAndQueryName.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder4);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder4.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFieldName);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFieldName.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder5);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder5.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryText);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryText.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder6);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder6.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffModelId);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffModelId.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder7);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder7.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffK);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffK.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder8);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder8.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffBoost);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffBoost.hashCode());

assertNotEquals(neuralQueryBuilder1, neuralQueryBuilder9);
assertNotEquals(neuralQueryBuilder1.hashCode(), neuralQueryBuilder9.hashCode());
assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName);
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode());
}
}

0 comments on commit 365e0a7

Please sign in to comment.