-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Validate zero vector when using cosine metric #1501
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: panguixin <panguixin@bytedance.com>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.common; | ||
|
||
public class KNNVectorUtil { | ||
private KNNVectorUtil() {} | ||
|
||
/** | ||
* Check if all the elements of a given vector are zero | ||
* | ||
* @param vector the vector | ||
* @return true if yes; otherwise false | ||
*/ | ||
public static boolean isZeroVector(byte[] vector) { | ||
for (byte e : vector) { | ||
if (e != (byte) 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. qq: Do we need the `(byte)` cast on 0? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not necessary; it's just my habit. I can remove it. |
||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
/** | ||
* Check if all the elements of a given vector are zero | ||
* | ||
* @param vector the vector | ||
* @return true if yes; otherwise false | ||
*/ | ||
public static boolean isZeroVector(float[] vector) { | ||
for (float e : vector) { | ||
if (e != 0f) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import lombok.Getter; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.opensearch.Version; | ||
import org.opensearch.common.Nullable; | ||
import org.opensearch.common.ValidationException; | ||
import org.opensearch.knn.common.KNNConstants; | ||
|
||
|
@@ -35,6 +36,7 @@ | |
import org.opensearch.knn.index.KNNMethodContext; | ||
import org.opensearch.knn.index.KNNSettings; | ||
import org.opensearch.knn.index.KNNVectorIndexFieldData; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.VectorDataType; | ||
import org.opensearch.knn.index.VectorField; | ||
import org.opensearch.knn.index.util.KNNEngine; | ||
|
@@ -55,6 +57,8 @@ | |
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; | ||
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; | ||
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; | ||
|
@@ -281,7 +285,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { | |
|
||
return new ModelFieldMapper( | ||
name, | ||
new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString), | ||
new KNNVectorFieldType(buildFullName(context), metaValue, -1, modelIdAsString), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we making this change, and is it required? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would recommend keeping it as is. This will ensure that changes are minimal for this PR. |
||
multiFieldsBuilder, | ||
copyToBuilder, | ||
ignoreMalformed, | ||
|
@@ -313,7 +317,13 @@ public KNNVectorFieldMapper build(BuilderContext context) { | |
|
||
return new LegacyFieldMapper( | ||
name, | ||
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()), | ||
new KNNVectorFieldType( | ||
buildFullName(context), | ||
metaValue, | ||
dimension.getValue(), | ||
vectorDataType.getValue(), | ||
SpaceType.getSpace(spaceType) | ||
), | ||
multiFieldsBuilder, | ||
copyToBuilder, | ||
ignoreMalformed, | ||
|
@@ -384,17 +394,24 @@ public static class KNNVectorFieldType extends MappedFieldType { | |
String modelId; | ||
KNNMethodContext knnMethodContext; | ||
VectorDataType vectorDataType; | ||
SpaceType spaceType; | ||
|
||
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, VectorDataType vectorDataType) { | ||
this(name, meta, dimension, null, null, vectorDataType); | ||
public KNNVectorFieldType( | ||
String name, | ||
Map<String, String> meta, | ||
int dimension, | ||
VectorDataType vectorDataType, | ||
SpaceType spaceType | ||
) { | ||
this(name, meta, dimension, null, null, vectorDataType, spaceType); | ||
} | ||
|
||
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) { | ||
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); | ||
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); | ||
} | ||
|
||
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { | ||
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD); | ||
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) { | ||
this(name, meta, dimension, null, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); | ||
} | ||
|
||
public KNNVectorFieldType( | ||
|
@@ -404,22 +421,24 @@ public KNNVectorFieldType( | |
KNNMethodContext knnMethodContext, | ||
VectorDataType vectorDataType | ||
) { | ||
this(name, meta, dimension, knnMethodContext, null, vectorDataType); | ||
this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); | ||
} | ||
|
||
public KNNVectorFieldType( | ||
String name, | ||
Map<String, String> meta, | ||
int dimension, | ||
KNNMethodContext knnMethodContext, | ||
String modelId, | ||
VectorDataType vectorDataType | ||
@Nullable KNNMethodContext knnMethodContext, | ||
@Nullable String modelId, | ||
VectorDataType vectorDataType, | ||
@Nullable SpaceType spaceType | ||
navneet1v marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) { | ||
super(name, false, false, true, TextSearchInfo.NONE, meta); | ||
this.dimension = dimension; | ||
this.modelId = modelId; | ||
this.knnMethodContext = knnMethodContext; | ||
this.vectorDataType = vectorDataType; | ||
this.spaceType = spaceType; | ||
} | ||
|
||
@Override | ||
|
@@ -496,34 +515,35 @@ protected String contentType() { | |
|
||
@Override | ||
protected void parseCreateField(ParseContext context) throws IOException { | ||
parseCreateField(context, fieldType().getDimension()); | ||
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType()); | ||
} | ||
|
||
protected void parseCreateField(ParseContext context, int dimension) throws IOException { | ||
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException { | ||
|
||
validateIfKNNPluginEnabled(); | ||
validateIfCircuitBreakerIsNotTriggered(); | ||
|
||
if (VectorDataType.BYTE == vectorDataType) { | ||
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension); | ||
|
||
if (!bytesArrayOptional.isPresent()) { | ||
if (bytesArrayOptional.isEmpty()) { | ||
return; | ||
} | ||
final byte[] array = bytesArrayOptional.get(); | ||
validateByteVector(array, spaceType); | ||
VectorField point = new VectorField(name(), array, fieldType); | ||
|
||
context.doc().add(point); | ||
addStoredFieldForVectorField(context, fieldType, name(), point.toString()); | ||
} else if (VectorDataType.FLOAT == vectorDataType) { | ||
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension); | ||
|
||
if (!floatsArrayOptional.isPresent()) { | ||
if (floatsArrayOptional.isEmpty()) { | ||
return; | ||
} | ||
final float[] array = floatsArrayOptional.get(); | ||
validateFloatVector(array, spaceType); | ||
VectorField point = new VectorField(name(), array, fieldType); | ||
|
||
context.doc().add(point); | ||
addStoredFieldForVectorField(context, fieldType, name(), point.toString()); | ||
} else { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import org.apache.lucene.index.DocValuesType; | ||
import org.opensearch.index.mapper.ParametrizedFieldMapper; | ||
import org.opensearch.index.mapper.ParseContext; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.VectorDataType; | ||
import org.opensearch.knn.index.util.KNNEngine; | ||
|
||
|
@@ -24,6 +25,7 @@ | |
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; | ||
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; | ||
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; | ||
import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector; | ||
|
||
public class KNNVectorFieldMapperUtil { | ||
/** | ||
|
@@ -74,6 +76,34 @@ public static void validateByteVectorValue(float value) { | |
} | ||
} | ||
|
||
/** | ||
* Validate if the given byte vector is supported by the given space type | ||
* | ||
* @param vector the given vector | ||
* @param spaceType the given space type | ||
*/ | ||
public static void validateByteVector(byte[] vector, SpaceType spaceType) { | ||
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) { | ||
throw new IllegalArgumentException( | ||
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue()) | ||
); | ||
} | ||
} | ||
|
||
/** | ||
* Validate if the given float vector is supported by the given space type | ||
* | ||
* @param vector the given vector | ||
* @param spaceType the given space type | ||
*/ | ||
public static void validateFloatVector(float[] vector, SpaceType spaceType) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we make these part of the SpaceType enum class? Also, do these validations hold for all engines? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The zero vector validation is the same for all engines, but that may not be true for other validations introduced in the future. For example, the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If it's true for all engines, then it makes more sense to add it as part of the enum class. |
||
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) { | ||
throw new IllegalArgumentException( | ||
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue()) | ||
); | ||
} | ||
} | ||
|
||
/** | ||
* Validate if the given vector size matches with the dimension provided in mapping. | ||
* | ||
|
@@ -85,7 +115,6 @@ public static void validateVectorDimension(int dimension, int vectorSize) { | |
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); | ||
throw new IllegalArgumentException(errorMessage); | ||
} | ||
|
||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
import org.opensearch.index.mapper.NumberFieldMapper; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.knn.index.KNNMethodContext; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.VectorDataType; | ||
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; | ||
import org.opensearch.knn.index.util.KNNEngine; | ||
|
@@ -34,7 +35,9 @@ | |
import java.util.Objects; | ||
|
||
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; | ||
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems this utility class should be renamed, or the functions that are shared between the mapper and the query should be moved into a KNNValidationUtil, so that they can be shared between the Query and Mapper classes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense. I can pull these out into a separate |
||
|
||
/** | ||
* Helper class to build the KNN query | ||
|
@@ -284,12 +287,15 @@ protected Query doToQuery(QueryShardContext context) { | |
KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); | ||
KNNEngine knnEngine = KNNEngine.DEFAULT; | ||
VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); | ||
SpaceType spaceType = knnVectorFieldType.getSpaceType(); | ||
|
||
if (fieldDimension == -1) { | ||
// If dimension is not set, the field uses a model and the information needs to be retrieved from there | ||
assert spaceType == null; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); | ||
fieldDimension = modelMetadata.getDimension(); | ||
knnEngine = modelMetadata.getKnnEngine(); | ||
spaceType = modelMetadata.getSpaceType(); | ||
} else if (knnMethodContext != null) { | ||
// If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping | ||
knnEngine = knnMethodContext.getKnnEngine(); | ||
|
@@ -308,6 +314,9 @@ protected Query doToQuery(QueryShardContext context) { | |
validateByteVectorValue(vector[i]); | ||
byteVector[i] = (byte) vector[i]; | ||
} | ||
validateByteVector(byteVector, spaceType); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we do this validation while reading the vectors in above loop? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we cannot reuse the validation function or short circuit like |
||
} else { | ||
validateFloatVector(vector, spaceType); | ||
} | ||
|
||
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use @NoArgsConstructor(access = AccessLevel.PRIVATE)