Validate zero vector when using cosine metric #1501

Merged
validate zero vector when using cosine metric
Signed-off-by: panguixin <panguixin@bytedance.com>
bugmakerrrrrr committed Mar 14, 2024
commit d396277395f91cb961178f3a9f780ed96cfdea94
40 changes: 40 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

public class KNNVectorUtil {
private KNNVectorUtil() {}
Collaborator

Use Lombok's @NoArgsConstructor(access = AccessLevel.PRIVATE) instead of a hand-written private constructor.


/**
* Check if all the elements of a given vector are zero
*
* @param vector the vector
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(byte[] vector) {
for (byte e : vector) {
if (e != (byte) 0) {
Member

Quick question: do we need the (byte) cast on 0?

Contributor Author

It's not necessary; it's just a habit of mine. I can remove it.

return false;
}
}
return true;
}

/**
* Check if all the elements of a given vector are zero
*
* @param vector the vector
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(float[] vector) {
for (float e : vector) {
if (e != 0f) {
return false;
}
}
return true;
}
}
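
For trying out edge cases outside the plugin, here is a minimal standalone sketch of the two overloads. The class name ZeroVectorSketch is made up for illustration; only the loop bodies mirror the diff. It also illustrates the point from the review thread: comparing a byte against the int literal 0 behaves the same as comparing against (byte) 0, because the byte operand is promoted to int before the comparison either way.

```java
// Standalone sketch; ZeroVectorSketch is a hypothetical name, not part of the PR.
final class ZeroVectorSketch {
    private ZeroVectorSketch() {}

    /** Returns true if every element is zero (an empty array also counts as zero). */
    static boolean isZeroVector(byte[] vector) {
        for (byte e : vector) {
            // No (byte) cast needed: e is promoted to int before the comparison.
            if (e != 0) {
                return false;
            }
        }
        return true;
    }

    /** Returns true if every element compares equal to 0f; -0.0f counts as zero, NaN does not. */
    static boolean isZeroVector(float[] vector) {
        for (float e : vector) {
            if (e != 0f) {
                return false;
            }
        }
        return true;
    }
}
```

Both loops return on the first non-zero element, so the cost is only linear in the worst case, which is the all-zero vector itself.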
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.common.Nullable;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;

@@ -35,6 +36,7 @@
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
@@ -55,6 +57,8 @@
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
@@ -281,7 +285,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString),
new KNNVectorFieldType(buildFullName(context), metaValue, -1, modelIdAsString),
Collaborator

Why are we making this change, and is it required?

Contributor Author

knnMethodContext is always null for ModelFieldMapper, so I believe this change makes that clearer.

Collaborator

I would recommend keeping it as is. This will ensure that changes are minimal for this PR.

multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
@@ -313,7 +317,13 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new LegacyFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()),
new KNNVectorFieldType(
buildFullName(context),
metaValue,
dimension.getValue(),
vectorDataType.getValue(),
SpaceType.getSpace(spaceType)
),
multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
@@ -384,17 +394,24 @@ public static class KNNVectorFieldType extends MappedFieldType {
String modelId;
KNNMethodContext knnMethodContext;
VectorDataType vectorDataType;
SpaceType spaceType;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, VectorDataType vectorDataType) {
this(name, meta, dimension, null, null, vectorDataType);
public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
VectorDataType vectorDataType,
SpaceType spaceType
) {
this(name, meta, dimension, null, null, vectorDataType, spaceType);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType());
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD);
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
this(name, meta, dimension, null, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null);
}

public KNNVectorFieldType(
@@ -404,22 +421,24 @@ public KNNVectorFieldType(
KNNMethodContext knnMethodContext,
VectorDataType vectorDataType
) {
this(name, meta, dimension, knnMethodContext, null, vectorDataType);
this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType());
}

public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
KNNMethodContext knnMethodContext,
String modelId,
VectorDataType vectorDataType
@Nullable KNNMethodContext knnMethodContext,
@Nullable String modelId,
VectorDataType vectorDataType,
@Nullable SpaceType spaceType
navneet1v marked this conversation as resolved.
) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
this.knnMethodContext = knnMethodContext;
this.vectorDataType = vectorDataType;
this.spaceType = spaceType;
}

@Override
@@ -496,34 +515,35 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension());
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType());
}

protected void parseCreateField(ParseContext context, int dimension) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);

if (!bytesArrayOptional.isPresent()) {
if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
validateByteVector(array, spaceType);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
if (floatsArrayOptional.isEmpty()) {
return;
}
final float[] array = floatsArrayOptional.get();
validateFloatVector(array, spaceType);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else {
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
import org.apache.lucene.index.DocValuesType;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

@@ -24,6 +25,7 @@
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

public class KNNVectorFieldMapperUtil {
/**
@@ -74,6 +76,34 @@ public static void validateByteVectorValue(float value) {
}
}

/**
* Validate if the given byte vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateByteVector(byte[] vector, SpaceType spaceType) {
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}

/**
* Validate if the given float vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateFloatVector(float[] vector, SpaceType spaceType) {
Collaborator

Should we make these part of the SpaceType enum class? Also, do these validations hold for all engines?

Contributor Author

The zero vector validation is the same for all engines, but that may not hold for other validations introduced in the future. For example, the dot product metric in Lucene requires all vectors to be normalized, but Faiss doesn't, so I'd prefer to keep these checks in this class. Any thoughts?

Collaborator

If it's true for all engines, then it makes more sense to add it as part of the enum class.

if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}
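
The reviewer's suggestion of moving the check into the enum could look roughly like the sketch below. SpaceKindSketch is a hypothetical stand-in for the plugin's SpaceType, with made-up constants and a made-up validateVector method; it only shows the shape of the idea, assuming the check really is engine-independent. Engine-specific rules, such as the normalization requirement the author mentions, would not fit this shape without also passing the engine in.

```java
import java.util.Locale;

// Hypothetical stand-in for the plugin's SpaceType enum; names and values are illustrative only.
enum SpaceKindSketch {
    L2("l2"),
    COSINESIMIL("cosinesimil") {
        @Override
        void validateVector(float[] vector) {
            // Cosine similarity is undefined for a vector with zero magnitude.
            if (isZero(vector)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue())
                );
            }
        }
    };

    private final String value;

    SpaceKindSketch(String value) {
        this.value = value;
    }

    String getValue() {
        return value;
    }

    /** Default: every vector is acceptable. Constants override this when the metric has constraints. */
    void validateVector(float[] vector) {}

    /** Returns true if every element compares equal to 0f. */
    static boolean isZero(float[] vector) {
        for (float e : vector) {
            if (e != 0f) {
                return false;
            }
        }
        return true;
    }
}
```

With this layout a caller would write spaceType.validateVector(array) and would not need to know which metrics carry constraints.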

/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
@@ -85,7 +115,6 @@ public static void validateVectorDimension(int dimension, int vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
}

}

/**
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
@@ -26,6 +27,8 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;

/**
* Field mapper for case when Lucene has been set as an engine.
@@ -75,7 +78,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected void parseCreateField(ParseContext context, int dimension) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
@@ -86,6 +89,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
return;
}
final byte[] array = bytesArrayOptional.get();
validateByteVector(array, spaceType);
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
@@ -101,7 +105,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx
return;
}
final float[] array = floatsArrayOptional.get();

validateFloatVector(array, spaceType);
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
Original file line number Diff line number Diff line change
@@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension());
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType());
}
}
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
@@ -34,7 +35,9 @@
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
Collaborator

It seems this utility class should be renamed, or the functions shared between the mapper and the query code should be moved into a KNNValidationUtil so that both the Query and Mapper classes can use them.

Contributor Author

Makes sense. I can pull these out into a separate KNNValidationUtil class.


/**
* Helper class to build the KNN query
@@ -284,12 +287,15 @@ protected Query doToQuery(QueryShardContext context) {
KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext();
KNNEngine knnEngine = KNNEngine.DEFAULT;
VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType();
SpaceType spaceType = knnVectorFieldType.getSpaceType();

if (fieldDimension == -1) {
// If dimension is not set, the field uses a model and the information needs to be retrieved from there
assert spaceType == null;
Collaborator

  1. Putting an assert here is dangerous, as it will lead to exception messages that users cannot understand. Let's make sure proper exceptions are thrown here.

  2. Because of this change, the comment has moved away from the code it describes. Let's make sure the comment is in the right position.
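
As background for point 1: Java assertions are disabled unless the JVM is started with -ea, and a failing assert throws AssertionError, which is an Error rather than an Exception. A minimal sketch of an explicit check follows. The class ModelFieldCheckSketch, the method requireNoSpaceType, and the message wording are all hypothetical, not code from the PR; the exception type is only one possible choice.

```java
// Hypothetical helper illustrating the reviewer's point; not code from the PR.
final class ModelFieldCheckSketch {
    private ModelFieldCheckSketch() {}

    /**
     * For a model-based field (dimension == -1) the mapping is expected to carry no space type.
     * Throwing explicitly gives a readable message whether or not the JVM runs with -ea.
     */
    static void requireNoSpaceType(int fieldDimension, Object spaceType) {
        if (fieldDimension == -1 && spaceType != null) {
            throw new IllegalStateException(
                "Field is backed by a model, so the space type must come from the model metadata, but the mapping has ["
                    + spaceType
                    + "]"
            );
        }
    }
}
```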

ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
fieldDimension = modelMetadata.getDimension();
knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else if (knnMethodContext != null) {
// If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping
knnEngine = knnMethodContext.getKnnEngine();
@@ -308,6 +314,9 @@ protected Query doToQuery(QueryShardContext context) {
validateByteVectorValue(vector[i]);
byteVector[i] = (byte) vector[i];
}
validateByteVector(byteVector, spaceType);
Collaborator

Can we do this validation while reading the vectors in the loop above?

Contributor Author

We cannot reuse the validation function, or short-circuit the way KNNVectorUtil.isZeroVector does, if we implement it in the loop, so I'd prefer not to.

} else {
validateFloatVector(vector, spaceType);
}
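
The trade-off discussed in the thread above can be sketched as follows: per-element checks stay inside the conversion loop, while the whole-vector check remains a separate call that can be reused elsewhere and stops at the first non-zero element. ByteQuerySketch and toByteVector are hypothetical names, and the per-element rule shown (a whole number within the signed byte range) is an assumption about what validateByteVectorValue enforces.

```java
import java.util.Locale;

// Hypothetical sketch; ByteQuerySketch and its method names are not part of the PR.
final class ByteQuerySketch {
    private ByteQuerySketch() {}

    /** Converts a float query vector to bytes, checking each element, then checks the whole vector once. */
    static byte[] toByteVector(float[] query, boolean cosine) {
        byte[] bytes = new byte[query.length];
        for (int i = 0; i < query.length; i++) {
            float value = query[i];
            // Per-element check (assumed rule): a whole number within the signed byte range.
            if (value % 1 != 0 || value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "value [%s] at index %d is not a valid byte", value, i)
                );
            }
            bytes[i] = (byte) value;
        }
        // Whole-vector check stays a separate, reusable call that stops at the first non-zero element.
        if (cosine && isZeroVector(bytes)) {
            throw new IllegalArgumentException("zero vector is not supported for cosine similarity");
        }
        return bytes;
    }

    /** Returns true if every element is zero. */
    static boolean isZeroVector(byte[] vector) {
        for (byte e : vector) {
            if (e != 0) {
                return false;
            }
        }
        return true;
    }
}
```

Folding the zero check into the conversion loop would save one pass over the array, at the cost of duplicating the rule wherever vectors are parsed.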

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
package org.opensearch.knn.plugin.script;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
@@ -18,6 +19,7 @@
import java.util.Map;
import java.util.function.BiFunction;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
@@ -90,14 +92,15 @@ class CosineSimilarity implements KNNScoringSpace {
*/
public CosineSimilarity(Object query, MappedFieldType fieldType) {
if (!isKNNVectorFieldType(fieldType)) {
throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + "be knn_vector.");
throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must be knn_vector.");
}

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
validateFloatVector(processedQuery, SpaceType.COSINESIMIL);
float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery);
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
}
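
For context on why the zero vector is rejected here: cosine similarity divides the dot product by the product of the two magnitudes, so a zero-magnitude vector makes the denominator zero. The sketch below uses the plain textbook formula, not the plugin's KNNScoringUtil.cosinesimilOptimized, and the class name CosineZeroSketch is made up for illustration.

```java
// Textbook cosine similarity; a hypothetical sketch, not the plugin's optimized scoring code.
final class CosineZeroSketch {
    private CosineZeroSketch() {}

    /** Returns dot(q, v) / (|q| * |v|); the result is NaN when either vector has zero magnitude. */
    static float cosine(float[] q, float[] v) {
        float dot = 0f;
        float qSquared = 0f;
        float vSquared = 0f;
        for (int i = 0; i < q.length; i++) {
            dot += q[i] * v[i];
            qSquared += q[i] * q[i];
            vSquared += v[i] * v[i];
        }
        // With a zero vector this becomes 0 / 0, which is NaN in floating-point arithmetic.
        return (float) (dot / Math.sqrt((double) qSquared * vSquared));
    }
}
```

A NaN here would carry through an expression like 1 + cosine, so rejecting the query vector up front gives the caller a clear error instead of an undefined score.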