Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix few bugs on binary index with Faiss HNSW #1850

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,19 @@ public static ValidationException validateKnnField(
return exception;
}

String vectorDataType = (String) fieldMap.get(VECTOR_DATA_TYPE_FIELD);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can fieldMap.get(VECTOR_DATA_TYPE_FIELD) be null?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be null.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it can be null then the type casting to string will cause a NPE. Please handle this gracefully.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't cause a NPE.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?? casting a null to string cause NPE. I am missing something here.

Copy link
Collaborator Author

@heemin32 heemin32 Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it and it didn't cause NPE. String is of Object and String can be null.

if (VectorDataType.BINARY.toString().equalsIgnoreCase(vectorDataType)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should convert the datatype to enum and then use == in if condition

Copy link
Collaborator Author

@heemin32 heemin32 Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because vectorDataType can be null, I choose this approach. Why should we convert it to the data type?
FYI, this line of code will be removed once we support binary index for IVF.

exception.addValidationError(
String.format(
Locale.ROOT,
"Field \"%s\" is of data type %s. Only FLOAT or BYTE is supported.",
field,
VectorDataType.BINARY
)
);
return exception;
}

// Return if dimension does not need to be checked
if (expectedDimension < 0) {
return null;
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -57,8 +58,10 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
Expand Down Expand Up @@ -90,8 +93,10 @@ public ValidationException validateWithData(KNNMethodContext knnMethodContext, V
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
Locale.ROOT,
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
knnMethodContext.getSpaceType().getValue()
)
);
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public class KNNSettings {
*/
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2";
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hammingbit";
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16;
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100;
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 100;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

// Build legacy
if (this.spaceType == null) {
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings());
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue());
}

if (this.m == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -78,17 +79,19 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() {
);
}

static String getSpaceType(Settings indexSettings) {
static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) {
String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
if (spaceType == null) {
spaceType = VectorDataType.BINARY == vectorDataType
? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY
: KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
log.info(
String.format(
"[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s",
METHOD_PARAMETER_SPACE_TYPE,
KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE
spaceType
)
);
return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
}
return spaceType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,9 @@ protected Query doToQuery(QueryShardContext context) {
String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine)
);
}
if (vectorDataType == VectorDataType.BINARY) {
throw new UnsupportedOperationException(String.format(Locale.ROOT, "Binary data type does not support radial search"));
}
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,15 @@ private boolean canDoExactSearch(final int filterIdsCount) {
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
return filterThresholdValue >= filterIdsCount;
}

// if no setting is set, then use the default max distance computation value to see if we can do exact search.
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * knnQuery.getQueryVector().length;
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines +486 to +487
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the fact that we have two vectors. This will create a lot of branching in the code. We should have gone with generics or something with a different type of query. Something similar to lucene.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are already aware of it and captured it in #1810

}

/**
Expand Down
21 changes: 21 additions & 0 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ public void testValidateKnnField_EmptyIndexMetadata() {
assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
}

public void testValidateKnnField_whenBinaryDataType_thenThrowException() {
Map<String, Object> fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "BINARY");
Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
Map<String, Object> properties = Map.of("properties", top_level_field);
String field = "top_level_field";
int dimension = 8;

MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assert (Objects.requireNonNull(e).getMessage().contains("is of data type BINARY. Only FLOAT or BYTE is supported"));
}

public void testIsShareableStateContainedInIndex_whenIndexNotModelBased_thenReturnFalse() {
String modelId = null;
KNNEngine knnEngine = KNNEngine.FAISS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,38 @@ public void testBuilder_build_fromLegacy() {
ModelDao modelDao = mock(ModelDao.class);
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);

SpaceType spaceType = SpaceType.COSINESIMIL;
int m = 17;
int efConstruction = 17;

// Setup settings
Settings settings = Settings.builder()
.put(settings(CURRENT).build())
.put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue())
.put(KNNSettings.KNN_ALGO_PARAM_M, m)
.put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction)
.build();

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);

assertNull(knnVectorFieldMapper.modelId);
assertNull(knnVectorFieldMapper.knnMethod);
assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
}

public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() {
// Check legacy is picked up if model context and method context are not set
ModelDao modelDao = mock(ModelDao.class);
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);
builder.vectorDataType.setValue(VectorDataType.BINARY);
builder.dimension.setValue(8);

// Setup settings
Settings settings = Settings.builder().put(settings(CURRENT).build()).build();

Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);
assertEquals(SpaceType.HAMMING_BIT.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
}

public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,30 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() {
float[] queryVector = { 1.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.maxDistance(MAX_DISTANCE)
.build();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(8);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
MethodComponentContext methodComponentContext = new MethodComponentContext(
org.opensearch.knn.common.KNNConstants.METHOD_HNSW,
ImmutableMap.of()
);
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.HAMMING_BIT, methodComponentContext);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertTrue(e.getMessage().contains("Binary data type does not support radial search"));
}

public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
// Given
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,74 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

/**
* This test ensure that we do the exact search when threshold settings are correct and not using filteredIds<=K
* condition to do exact search on binary index
* FilteredIdThreshold: 10
* FilteredIdThresholdPct: 10%
* FilteredIdsCount: 6
* liveDocs : null, as there is no deleted documents
* MaxDoc: 100
* K : 1
*/
@SneakyThrows
public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryIndex_thenSuccess() {
knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10);
byte[] vector = new byte[] { 1, 3 };
int k = 1;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(100);
when(reader.getLiveDocs()).thenReturn(null);
final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);

when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length));

final KNNQuery query = new KNNQuery(FIELD_NAME, BYTE_QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, VectorDataType.BINARY);

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.HAMMING_BIT.name(),
PARAMETERS,
String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "BHNSW32")
);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);
when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING_BIT.getValue());
when(fieldInfo.getName()).thenReturn(FIELD_NAME);
when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues);
when(binaryDocValues.advance(0)).thenReturn(0);
BytesRef vectorByteRef = new BytesRef(vector);
when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef);

final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down
Loading