Skip to content

Support cosine similarity in kNN search #79500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ setup:
type: dense_vector
dims: 5
index: true
similarity: dot_product

similarity: cosine
- do:
index:
index: test-index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
type: dense_vector
dims: 3
index: true
similarity: dot_product
similarity: cosine

- do:
bulk:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
Expand All @@ -31,14 +27,20 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.xpack.vectors.query.KnnVectorFieldExistsQuery;
import org.elasticsearch.xpack.vectors.query.KnnVectorQueryBuilder;
import org.elasticsearch.xpack.vectors.query.VectorIndexFieldData;

import java.io.IOException;
Expand Down Expand Up @@ -107,7 +109,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
return new DenseVectorFieldMapper(
name,
new DenseVectorFieldType(context.buildFullName(name), indexVersionCreated,
dims.getValue(), indexed.getValue(), meta.getValue()),
dims.getValue(), indexed.getValue(), similarity.getValue(), meta.getValue()),
dims.getValue(),
indexed.getValue(),
similarity.getValue(),
Expand All @@ -120,6 +122,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {

enum VectorSimilarity {
l2_norm(VectorSimilarityFunction.EUCLIDEAN),
cosine(VectorSimilarityFunction.COSINE),
dot_product(VectorSimilarityFunction.DOT_PRODUCT);

public final VectorSimilarityFunction function;
Expand Down Expand Up @@ -195,19 +198,18 @@ public String toString() {
public static final class DenseVectorFieldType extends SimpleMappedFieldType {
private final int dims;
private final boolean indexed;
private final VectorSimilarity similarity;
private final Version indexVersionCreated;

public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, boolean indexed, Map<String, String> meta) {
public DenseVectorFieldType(String name, Version indexVersionCreated, int dims,boolean indexed,
VectorSimilarity similarity, Map<String, String> meta) {
super(name, indexed, false, indexed == false, TextSearchInfo.NONE, meta);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
this.indexVersionCreated = indexVersionCreated;
}

public int dims() {
return dims;
}

@Override
public String typeName() {
return CONTENT_TYPE;
Expand Down Expand Up @@ -257,6 +259,48 @@ public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support queries");
}

public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
if (isSearchable() == false) {
throw new IllegalArgumentException("[" + KnnVectorQueryBuilder.NAME + "] " +
"queries are not supported if [index] is disabled");
}

if (queryVector.length != dims) {
throw new IllegalArgumentException("the query vector has a different dimension [" + queryVector.length + "] "
+ "than the index vectors [" + dims + "]");
}

if (similarity == VectorSimilarity.dot_product) {
double squaredMagnitude = 0.0;
for (float e : queryVector) {
squaredMagnitude += e * e;
}
checkVectorMagnitude(queryVector, squaredMagnitude);
}
return new KnnVectorQuery(name(), queryVector, numCands);
}

private void checkVectorMagnitude(float[] vector, double squaredMagnitude) {
if (Math.abs(squaredMagnitude - 1.0f) > 1e-4) {
// Include the first five elements of the invalid vector in the error message
StringBuilder sb = new StringBuilder("The [" + VectorSimilarity.dot_product.name() + "] similarity can " +
"only be used with unit-length vectors. Preview of invalid vector: ");
sb.append("[");
for (int i = 0; i < Math.min(5, vector.length); i++) {
if (i > 0) {
sb.append(", ");
}
sb.append(vector[i]);
}
if (vector.length >= 5) {
sb.append(", ...");
}
sb.append("]");

throw new IllegalArgumentException(sb.toString());
}
}
}

private final int dims;
Expand Down Expand Up @@ -301,13 +345,20 @@ public void parse(DocumentParserContext context) throws IOException {

private Field parseKnnVector(DocumentParserContext context) throws IOException {
float[] vector = new float[dims];
double squaredMagnitude = 0.0;
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
checkDimensionExceeded(index, context);
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
vector[index++] = context.parser().floatValue(true);

float value = context.parser().floatValue(true);
vector[index++] = value;
squaredMagnitude += value * value;
}
checkDimensionMatches(index, context);
if (similarity == VectorSimilarity.dot_product) {
fieldType().checkVectorMagnitude(vector, squaredMagnitude);
}
return new KnnVectorField(fieldType().name(), vector, similarity.function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package org.elasticsearch.xpack.vectors.query;

import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -92,14 +91,7 @@ protected Query doToQuery(SearchExecutionContext context) {
}

DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
if (queryVector.length != vectorFieldType.dims()) {
throw new IllegalArgumentException("the query vector has a different dimension [" + queryVector.length + "] "
+ "than the index vectors [" + vectorFieldType.dims() + "]");
}
if (vectorFieldType.isSearchable() == false) {
throw new IllegalArgumentException("[" + NAME + "] queries are not supported if [index] is disabled");
}
return new KnnVectorQuery(fieldType.name(), queryVector, numCands);
return vectorFieldType.createKnnQuery(queryVector, numCands);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ protected void minimalMapping(XContentBuilder b) throws IOException {

@Override
protected Object getSampleValueForDocument() {
return List.of(1, 2, 3, 4);
return List.of(0.5, 0.5, 0.5, 0.5);
}

@Override
Expand Down Expand Up @@ -204,7 +204,7 @@ public void testIndexedVector() throws Exception {
.field("index", true)
.field("similarity", similarity.name())));

float[] vector = {-12.1f, 100.7f, -4};
float[] vector = {-0.5f, 0.5f, 0.7071f};
ParsedDocument doc1 = mapper.parse(source(b -> b.array("field", vector)));

IndexableField[] fields = doc1.rootDoc().getFields("field");
Expand All @@ -220,6 +220,30 @@ public void testIndexedVector() throws Exception {
assertEquals(similarity.function, vectorField.fieldType().vectorSimilarityFunction());
}

public void testDotProductWithInvalidNorm() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b
.field("type", "dense_vector")
.field("dims", 3)
.field("index", true)
.field("similarity", VectorSimilarity.dot_product)));
float[] vector = {-12.1f, 2.7f, -4};
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(source(b -> b.array("field", vector))));
assertNotNull(e.getCause());
assertThat(e.getCause().getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors. " +
"Preview of invalid vector: [-12.1, 2.7, -4.0]"));

DocumentMapper mapperWithLargerDim = createDocumentMapper(fieldMapping(b -> b
.field("type", "dense_vector")
.field("dims", 6)
.field("index", true)
.field("similarity", VectorSimilarity.dot_product)));
float[] largerVector = {-12.1f, 2.7f, -4, 1.05f, 10.0f, 29.9f};
e = expectThrows(MapperParsingException.class, () -> mapperWithLargerDim.parse(source(b -> b.array("field", largerVector))));
assertNotNull(e.getCause());
assertThat(e.getCause().getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors. " +
"Preview of invalid vector: [-12.1, 2.7, -4.0, 1.05, 10.0, ...]"));
}

public void testInvalidParameters() {
MapperParsingException e = expectThrows(MapperParsingException.class,
() -> createDocumentMapper(fieldMapping(b -> b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,70 @@

import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.FieldTypeTestCase;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper.VectorSimilarity;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.containsString;

public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
private final boolean indexed;

public DenseVectorFieldTypeTests() {
this.indexed = randomBoolean();
}

private DenseVectorFieldMapper.DenseVectorFieldType createFieldType() {
return new DenseVectorFieldMapper.DenseVectorFieldType("f", Version.CURRENT, 5, indexed, Collections.emptyMap());
private DenseVectorFieldType createFieldType() {
return new DenseVectorFieldType("f", Version.CURRENT, 5, indexed, VectorSimilarity.cosine, Collections.emptyMap());
}

public void testHasDocValues() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertNotEquals(indexed, ft.hasDocValues());
}

public void testIsSearchable() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertEquals(indexed, ft.isSearchable());
}

public void testIsAggregatable() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertFalse(ft.isAggregatable());
}

public void testFielddataBuilder() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertNotNull(ft.fielddataBuilder("index", () -> {
throw new UnsupportedOperationException();
}));
}

public void testDocValueFormat() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
expectThrows(IllegalArgumentException.class, () -> ft.docValueFormat(null, null));
}

public void testFetchSourceValue() throws IOException {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
List<Double> vector = List.of(0.0, 1.0, 2.0, 3.0, 4.0);
assertEquals(vector, fetchSourceValue(ft, vector));
}

public void testCreateKnnQuery() {
DenseVectorFieldType unindexedField = new DenseVectorFieldType("f", Version.CURRENT,
3, false, VectorSimilarity.cosine, Collections.emptyMap());
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> unindexedField.createKnnQuery(
new float[]{0.3f, 0.1f, 1.0f}, 10));
assertThat(e.getMessage(), containsString("[knn] queries are not supported if [index] is disabled"));

DenseVectorFieldType dotProductField = new DenseVectorFieldType("f", Version.CURRENT,
3, true, VectorSimilarity.dot_product, Collections.emptyMap());
e = expectThrows(IllegalArgumentException.class, () -> dotProductField.createKnnQuery(
new float[]{0.3f, 0.1f, 1.0f}, 10));
assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ public void testWrongFieldType() {
assertThat(e.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
}

public void testUnindexedField() {
SearchExecutionContext context = createSearchExecutionContext();
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(UNINDEXED_VECTOR_FIELD,
new float[]{1.0f, 1.0f, 1.0f}, 10);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
assertThat(e.getMessage(), containsString("[knn] queries are not supported if [index] is disabled"));
}

@Override
public void testValidOutput() {
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] {1.0f, 2.0f, 3.0f}, 10);
Expand Down