From 009c8f39d541f3edb2d665e4300257bee0d97d83 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:01:40 -0700 Subject: [PATCH] Fix mode/comp params so parameter overrides work (#2085) PR adds capability to override parameters when specifying mode and compression. In order to do this, I add functionality for creating a deep copy of KNNMethodContext and MethodComponentContext so that we wouldnt overwrite user provided config. Then, re-arranged some of the parameter resolution logic. Signed-off-by: John Mazanec (cherry picked from commit 270ac6a52f2ff9cc9cfa58dbd65333b21f925048) --- .../index/engine/AbstractMethodResolver.java | 185 ++++++++++ .../opensearch/knn/index/engine/Encoder.java | 12 + .../knn/index/engine/EngineResolver.java | 62 ++++ .../knn/index/engine/KNNEngine.java | 10 + .../knn/index/engine/KNNLibrary.java | 2 +- .../knn/index/engine/KNNMethodContext.java | 55 ++- .../knn/index/engine/MethodComponent.java | 5 +- .../index/engine/MethodComponentContext.java | 23 ++ .../knn/index/engine/MethodResolver.java | 36 ++ .../index/engine/ResolvedMethodContext.java | 23 ++ .../knn/index/engine/SpaceTypeResolver.java | 96 +++++ .../knn/index/engine/faiss/Faiss.java | 25 +- .../index/engine/faiss/FaissFlatEncoder.java | 11 + .../index/engine/faiss/FaissHNSWMethod.java | 26 +- .../engine/faiss/FaissHNSWPQEncoder.java | 12 + .../index/engine/faiss/FaissIVFMethod.java | 27 +- .../index/engine/faiss/FaissIVFPQEncoder.java | 12 + .../engine/faiss/FaissMethodResolver.java | 159 ++++++++ .../index/engine/faiss/FaissSQEncoder.java | 12 + .../index/engine/faiss/QFrameBitEncoder.java | 35 ++ .../knn/index/engine/lucene/Lucene.java | 17 + .../index/engine/lucene/LuceneHNSWMethod.java | 19 +- .../engine/lucene/LuceneMethodResolver.java | 106 ++++++ .../index/engine/lucene/LuceneSQEncoder.java | 12 + .../knn/index/engine/nmslib/Nmslib.java | 16 + .../index/engine/nmslib/NmslibHNSWMethod.java | 4 +- .../engine/nmslib/NmslibMethodResolver.java | 69 ++++ .../knn/index/mapper/CompressionLevel.java | 26 +- .../index/mapper/KNNVectorFieldMapper.java | 123 +++---- .../knn/index/mapper/KNNVectorFieldType.java | 5 +- .../knn/index/mapper/LuceneFieldMapper.java | 21 +- .../knn/index/mapper/MethodFieldMapper.java | 19 +- .../knn/index/mapper/ModeBasedResolver.java | 213 ----------- .../mapper/OriginalMappingParameters.java | 2 + .../opensearch/knn/index/util/IndexUtil.java | 5 +- .../plugin/rest/RestTrainModelHandler.java | 45 +-- .../transport/TrainingModelRequest.java | 19 +- .../java/org/opensearch/knn/KNNTestCase.java | 6 + .../index/engine/AbstractKNNLibraryTests.java | 10 + .../engine/AbstractMethodResolverTests.java | 158 ++++++++ .../knn/index/engine/EngineResolverTests.java | 152 ++++++++ .../knn/index/engine/NativeLibraryTests.java | 10 + .../index/engine/SpaceTypeResolverTests.java | 99 +++++ .../engine/faiss/FaissHNSWPQEncoderTests.java | 16 + .../engine/faiss/FaissIVFPQEncoderTests.java | 16 + .../faiss/FaissMethodResolverTests.java | 246 +++++++++++++ .../engine/faiss/FaissSQEncoderTests.java | 16 + .../engine/faiss/QFrameBitEncoderTests.java | 43 +++ .../lucene/LuceneMethodResolverTests.java | 212 +++++++++++ .../engine/lucene/LuceneSQEncoderTests.java | 16 + .../nmslib/NmslibMethodResolverTests.java | 106 ++++++ .../mapper/KNNVectorFieldMapperTests.java | 343 +++++++++++++++++- .../knn/integ/ModeAndCompressionIT.java | 107 ++++-- .../LibraryInitializedSupplierTests.java | 11 + ...TrainingJobRouterTransportActionTests.java | 6 +- .../transport/TrainingModelRequestTests.java | 72 +--- 56 files changed, 2715 insertions(+), 479 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/EngineResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/MethodResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/ResolvedMethodContext.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java create mode 100644 src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolver.java delete mode 100644 src/main/java/org/opensearch/knn/index/mapper/ModeBasedResolver.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolverTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java create mode 100644 src/test/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolverTests.java diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java new file mode 100644 index 000000000..8127a041d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; + +/** + * Abstract {@link MethodResolver} with helpful utilitiy functions that can be shared across different + * implementations + */ +public abstract class AbstractMethodResolver implements MethodResolver { + + /** + * Utility method to get the compression level from the context + * + * @param resolvedKnnMethodContext Resolved method context. Should have an encoder set in the params if available + * @return {@link CompressionLevel} Compression level that is configured with the {@link KNNMethodContext} + */ + protected CompressionLevel resolveCompressionLevelFromMethodContext( + KNNMethodContext resolvedKnnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + Map encoderMap + ) { + // If the context is null, the compression is not configured or the encoder is not defined, return not configured + // because the method context does not contain this info + if (isEncoderSpecified(resolvedKnnMethodContext) == false) { + return CompressionLevel.x1; + } + Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext)); + if (encoder == null) { + return CompressionLevel.NOT_CONFIGURED; + } + return encoder.calculateCompressionLevel(getEncoderComponentContext(resolvedKnnMethodContext), knnMethodConfigContext); + } + + protected void resolveMethodParams( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext, + MethodComponent methodComponent + ) { + Map resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded( + methodComponentContext, + methodComponent, + knnMethodConfigContext + ); + methodComponentContext.getParameters().putAll(resolvedParams); + } + + protected KNNMethodContext initResolvedKNNMethodContext( + KNNMethodContext originalMethodContext, + KNNEngine knnEngine, + SpaceType spaceType, + String methodName + ) { + if (originalMethodContext == null) { + return new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(methodName, new HashMap<>())); + } + return new KNNMethodContext(originalMethodContext); + } + + protected String getEncoderName(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext); + if (methodComponentContext == null) { + return null; + } + + return methodComponentContext.getName(); + } + + protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) { + if (isEncoderSpecified(knnMethodContext) == false) { + return null; + } + + return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER); + } + + /** + * Determine if the encoder parameter is specified + * + * @param knnMethodContext {@link KNNMethodContext} + * @return true is the encoder is specified in the structure; false otherwise + */ + protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) { + return knnMethodContext != null + && knnMethodContext.getMethodComponentContext().getParameters() != null + && knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER); + } + + protected boolean shouldEncoderBeResolved(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + // The encoder should not be resolved if: + // 1. The encoder is specified + // 2. The compression is x1 + // 3. The compression is not specified and the mode is not disk-based + if (isEncoderSpecified(knnMethodContext)) { + return false; + } + + if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1) { + return false; + } + + if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel()) == false + && Mode.ON_DISK != knnMethodConfigContext.getMode()) { + return false; + } + + if (VectorDataType.FLOAT != knnMethodConfigContext.getVectorDataType()) { + return false; + } + + return true; + } + + protected ValidationException validateNotTrainingContext( + boolean shouldRequireTraining, + KNNEngine knnEngine, + ValidationException validationException + ) { + if (shouldRequireTraining) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationError( + String.format(Locale.ROOT, "Cannot use \"%s\" engine from training context", knnEngine.getName()) + ); + } + + return validationException; + } + + protected ValidationException validateCompressionSupported( + CompressionLevel compressionLevel, + Set supportedCompressionLevels, + KNNEngine knnEngine, + ValidationException validationException + ) { + if (CompressionLevel.isConfigured(compressionLevel) && supportedCompressionLevels.contains(compressionLevel) == false) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationError( + String.format(Locale.ROOT, "\"%s\" does not support \"%s\" compression", knnEngine.getName(), compressionLevel.getName()) + ); + } + return validationException; + } + + protected ValidationException validateCompressionNotx1WhenOnDisk( + KNNMethodConfigContext knnMethodConfigContext, + ValidationException validationException + ) { + if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1 && knnMethodConfigContext.getMode() == Mode.ON_DISK) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationError( + String.format(Locale.ROOT, "Cannot specify \"x1\" compression level when using \"%s\" mode", Mode.ON_DISK.getName()) + ); + } + return validationException; + } + + protected void validateCompressionConflicts(CompressionLevel originalCompressionLevel, CompressionLevel resolvedCompressionLevel) { + if (CompressionLevel.isConfigured(originalCompressionLevel) + && CompressionLevel.isConfigured(resolvedCompressionLevel) + && resolvedCompressionLevel != originalCompressionLevel) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("Cannot specify an encoder that conflicts with the provided compression level"); + throw validationException; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/Encoder.java b/src/main/java/org/opensearch/knn/index/engine/Encoder.java index 7e22145eb..f15d0afcf 100644 --- a/src/main/java/org/opensearch/knn/index/engine/Encoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/Encoder.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.engine; +import org.opensearch.knn.index.mapper.CompressionLevel; + /** * Interface representing an encoder. An encoder generally refers to a vector quantizer. */ @@ -24,4 +26,14 @@ default String getName() { * @return Method component associated with the encoder */ MethodComponent getMethodComponent(); + + /** + * Calculate the compression level for the give params. Assume float32 vectors are used. All parameters should + * be resolved in the encoderContext passed in. + * + * @param encoderContext Context for the encoder to extract params from + * @return Compression level this encoder produces. If the encoder does not support this calculation yet, it will + * return {@link CompressionLevel#NOT_CONFIGURED} + */ + CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext); } diff --git a/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java new file mode 100644 index 000000000..daae361e4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +/** + * Figures out what {@link KNNEngine} to use based on configuration details + */ +public final class EngineResolver { + + public static final EngineResolver INSTANCE = new EngineResolver(); + + private EngineResolver() {} + + /** + * Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}. + * + * @param knnMethodConfigContext configuration context + * @param knnMethodContext KNNMethodContext + * @param requiresTraining whether config requires training + * @return {@link KNNEngine} + */ + public KNNEngine resolveEngine( + KNNMethodConfigContext knnMethodConfigContext, + KNNMethodContext knnMethodContext, + boolean requiresTraining + ) { + // User configuration gets precedence + if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) { + return knnMethodContext.getKnnEngine(); + } + + // Faiss is the only engine that supports training, so we default to faiss here for now + if (requiresTraining) { + return KNNEngine.FAISS; + } + + Mode mode = knnMethodConfigContext.getMode(); + CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel(); + // If both mode and compression are not specified, we can just default + if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) { + return KNNEngine.DEFAULT; + } + + // For 1x, we need to default to faiss if mode is provided and use nmslib otherwise + if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) { + return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT; + } + + // Lucene is only engine that supports 4x - so we have to default to it here. + if (compressionLevel == CompressionLevel.x4) { + return KNNEngine.LUCENE; + } + + return KNNEngine.FAISS; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 2f3cb3430..80b9f32a6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -201,4 +201,14 @@ public void setInitialized(Boolean isInitialized) { public List mmapFileExtensions() { return knnLibrary.mmapFileExtensions(); } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + return knnLibrary.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index 14085243f..cf7c4ad82 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -14,7 +14,7 @@ /** * KNNLibrary is an interface that helps the plugin communicate with k-NN libraries */ -public interface KNNLibrary { +public interface KNNLibrary extends MethodResolver { /** * Gets the version of the library that is being used. In general, this can be used for ensuring compatibility of diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java index 8b2f00f74..4a4c2704e 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.engine; +import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; @@ -34,17 +35,48 @@ * KNNMethodContext will contain the information necessary to produce a library index from an Opensearch mapping. * It will encompass all parameters necessary to build the index. */ -@AllArgsConstructor +@AllArgsConstructor(access = AccessLevel.PACKAGE) @Getter public class KNNMethodContext implements ToXContentFragment, Writeable { @NonNull - private final KNNEngine knnEngine; + private KNNEngine knnEngine; @NonNull @Setter private SpaceType spaceType; @NonNull private final MethodComponentContext methodComponentContext; + // Currently, the KNNEngine member variable cannot be null and defaults during parsing to nmslib. However, in order + // to support disk based engine resolution, this value potentially needs to be updated. Thus, this value is used + // to determine if the variable can be overridden or not based on whether the user explicitly set the value during parsing + private boolean isEngineConfigured; + + /** + * Copy constructor. Useful for creating a deep copy of a {@link KNNMethodContext}. Note that the engine and + * space type should be set. + * + * @param knnMethodContext original {@link KNNMethodContext}. Must NOT be null + */ + public KNNMethodContext(KNNMethodContext knnMethodContext) { + if (knnMethodContext == null) { + throw new IllegalArgumentException("KNNMethodContext cannot be null"); + } + + this.knnEngine = knnMethodContext.knnEngine; + this.spaceType = knnMethodContext.spaceType; + this.isEngineConfigured = true; + this.methodComponentContext = new MethodComponentContext(knnMethodContext.methodComponentContext); + } + + /** + * + * @param knnEngine {@link KNNEngine} + * @param spaceType {@link SpaceType} + * @param methodComponentContext {@link MethodComponentContext} + */ + public KNNMethodContext(KNNEngine knnEngine, SpaceType spaceType, MethodComponentContext methodComponentContext) { + this(knnEngine, spaceType, methodComponentContext, true); + } /** * Constructor from stream. @@ -56,6 +88,21 @@ public KNNMethodContext(StreamInput in) throws IOException { this.knnEngine = KNNEngine.getEngine(in.readString()); this.spaceType = SpaceType.getSpace(in.readString()); this.methodComponentContext = new MethodComponentContext(in); + this.isEngineConfigured = true; + } + + /** + * Set the {@link KNNEngine} if it is not configured (i.e. DEFAULT). This is useful for using different engines + * for different configurations - i.e. dynamic defaults + * + * @param knnEngine KNNEngine to set + */ + public void setKnnEngine(KNNEngine knnEngine) { + if (isEngineConfigured) { + throw new IllegalArgumentException("Cannot configure KNNEngine if it has already been configured"); + } + this.knnEngine = knnEngine; + this.isEngineConfigured = true; } /** @@ -101,6 +148,7 @@ public static KNNMethodContext parse(Object in) { @SuppressWarnings("unchecked") Map methodMap = (Map) in; + boolean isEngineConfigured = false; KNNEngine engine = KNNEngine.DEFAULT; // Get or default SpaceType spaceType = SpaceType.UNDEFINED; // Get or default String name = ""; @@ -123,6 +171,7 @@ public static KNNMethodContext parse(Object in) { throw new MapperParsingException("Invalid " + KNN_ENGINE + ": " + value); } } + isEngineConfigured = true; } else if (METHOD_PARAMETER_SPACE_TYPE.equals(key)) { if (value != null && !(value instanceof String)) { throw new MapperParsingException("\"" + METHOD_PARAMETER_SPACE_TYPE + "\" must be a string"); @@ -173,7 +222,7 @@ public static KNNMethodContext parse(Object in) { MethodComponentContext method = new MethodComponentContext(name, parameters); - return new KNNMethodContext(engine, spaceType, method); + return new KNNMethodContext(engine, spaceType, method, isEngineConfigured); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java index 2579063e9..75e18a243 100644 --- a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java @@ -342,7 +342,10 @@ public static Map getParameterMapWithDefaultsAdded( IndexHyperParametersUtil.getHNSWEFConstructionValue(indexCreationVersion) ); } else { - parametersWithDefaultsMap.put(parameter.getName(), parameter.getDefaultValue()); + Object value = parameter.getDefaultValue(); + if (value != null) { + parametersWithDefaultsMap.put(parameter.getName(), value); + } } } diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java index 586cc338f..1f0b345e9 100644 --- a/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/MethodComponentContext.java @@ -49,6 +49,29 @@ public class MethodComponentContext implements ToXContentFragment, Writeable { private final String name; private final Map parameters; + /** + * Copy constructor. Creates a deep copy of a {@link MethodComponentContext} + * + * @param methodComponentContext to be copied. Must NOT be null + */ + public MethodComponentContext(MethodComponentContext methodComponentContext) { + if (methodComponentContext == null) { + throw new IllegalArgumentException("MethodComponentContext cannot be null"); + } + + this.name = methodComponentContext.name; + this.parameters = new HashMap<>(); + if (methodComponentContext.parameters != null) { + for (Map.Entry entry : methodComponentContext.parameters.entrySet()) { + if (entry.getValue() instanceof MethodComponentContext) { + parameters.put(entry.getKey(), new MethodComponentContext((MethodComponentContext) entry.getValue())); + } else { + parameters.put(entry.getKey(), entry.getValue()); + } + } + } + } + /** * Constructor from stream. * diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/MethodResolver.java new file mode 100644 index 000000000..4df18ad72 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/MethodResolver.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.index.SpaceType; + +/** + * Interface for resolving the {@link ResolvedMethodContext} for an engine and configuration + */ +public interface MethodResolver { + + /** + * Creates a new {@link ResolvedMethodContext} filling parameters based on other configuration details. A validation + * exception will be thrown if the {@link KNNMethodConfigContext} is not compatible with the + * parameters provided by the user. + * + * @param knnMethodContext User provided information regarding the method context. A new context should be + * constructed. This variable will not be modified. + * @param knnMethodConfigContext Configuration details that can be used for resolving the defaults. Should not be null + * @param shouldRequireTraining Should the provided context require training + * @param spaceType Space type for the method. Cannot be null or undefined + * @return {@link ResolvedMethodContext} with dynamic defaults configured. This will include both the resolved + * compression as well as the completely resolve {@link KNNMethodContext}. + * This is guanteed to be a copy of the user provided context. + * @throws org.opensearch.common.ValidationException on invalid configuration and userprovided context. + */ + ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ); +} diff --git a/src/main/java/org/opensearch/knn/index/engine/ResolvedMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/ResolvedMethodContext.java new file mode 100644 index 000000000..1edc0a98e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/ResolvedMethodContext.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.knn.index.mapper.CompressionLevel; + +/** + * Small data class for storing info that gets resolved during resolution process + */ +@RequiredArgsConstructor +@Getter +@Builder +public class ResolvedMethodContext { + private final KNNMethodContext knnMethodContext; + @Builder.Default + private final CompressionLevel compressionLevel = CompressionLevel.NOT_CONFIGURED; +} diff --git a/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java b/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java new file mode 100644 index 000000000..a12ffbc7b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.apache.logging.log4j.util.Strings; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Locale; + +/** + * Class contains the logic to figure out what {@link SpaceType} to use based on configuration + * details. A user can either provide the {@link SpaceType} via the {@link KNNMethodContext} or through a + * top level parameter. This class will take care of this resolution logic (as well as if neither are configured) and + * ensure there are not any contradictions. + */ +public final class SpaceTypeResolver { + + public static final SpaceTypeResolver INSTANCE = new SpaceTypeResolver(); + + private SpaceTypeResolver() {} + + /** + * Resolves space type from configuration details. It is guaranteed not to return either null or + * {@link SpaceType#UNDEFINED} + * + * @param knnMethodContext Method context + * @param vectorDataType Vectordatatype + * @param topLevelSpaceTypeString Alternative top-level space type + * @return {@link SpaceType} for the method + */ + public SpaceType resolveSpaceType( + final KNNMethodContext knnMethodContext, + final VectorDataType vectorDataType, + final String topLevelSpaceTypeString + ) { + SpaceType methodSpaceType = getSpaceTypeFromMethodContext(knnMethodContext); + SpaceType topLevelSpaceType = getSpaceTypeFromString(topLevelSpaceTypeString); + + if (isSpaceTypeConfigured(methodSpaceType) == false && isSpaceTypeConfigured(topLevelSpaceType) == false) { + return getSpaceTypeFromVectorDataType(vectorDataType); + } + + if (isSpaceTypeConfigured(methodSpaceType) == false) { + return topLevelSpaceType; + } + + if (isSpaceTypeConfigured(topLevelSpaceType) == false) { + return methodSpaceType; + } + + if (methodSpaceType == topLevelSpaceType) { + return topLevelSpaceType; + } + + throw new MapperParsingException( + String.format( + Locale.ROOT, + "Cannot specify conflicting space types: \"[%s]\" \"[%s]\"", + methodSpaceType.getValue(), + topLevelSpaceType.getValue() + ) + ); + } + + private SpaceType getSpaceTypeFromMethodContext(final KNNMethodContext knnMethodContext) { + if (knnMethodContext == null) { + return SpaceType.UNDEFINED; + } + + return knnMethodContext.getSpaceType(); + } + + private SpaceType getSpaceTypeFromVectorDataType(final VectorDataType vectorDataType) { + if (vectorDataType == VectorDataType.BINARY) { + return SpaceType.DEFAULT_BINARY; + } + return SpaceType.DEFAULT; + } + + private SpaceType getSpaceTypeFromString(final String spaceType) { + if (Strings.isEmpty(spaceType)) { + return SpaceType.UNDEFINED; + } + + return SpaceType.getSpace(spaceType); + } + + private boolean isSpaceTypeConfigured(final SpaceType spaceType) { + return spaceType != null && spaceType != SpaceType.UNDEFINED; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index 329acbdb8..a602619a1 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -9,7 +9,11 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNMethod; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodResolver; import org.opensearch.knn.index.engine.NativeLibrary; +import org.opensearch.knn.index.engine.ResolvedMethodContext; import java.util.Map; import java.util.function.Function; @@ -41,12 +45,8 @@ public class Faiss extends NativeLibrary { SpaceType, Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); - private final static Map METHODS = ImmutableMap.of( - METHOD_HNSW, - new FaissHNSWMethod(), - METHOD_IVF, - new FaissIVFMethod() - ); + // Package private so that the method resolving logic can access the methods + final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod()); public final static Faiss INSTANCE = new Faiss( METHODS, @@ -56,6 +56,8 @@ public class Faiss extends NativeLibrary { SCORE_TO_DISTANCE_TRANSFORMATIONS ); + private final MethodResolver methodResolver; + /** * Constructor for Faiss * @@ -73,6 +75,7 @@ private Faiss( ) { super(methods, scoreTranslation, currentVersion, extension); this.scoreTransform = scoreTransform; + this.methodResolver = new FaissMethodResolver(); } @Override @@ -89,4 +92,14 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { } return spaceType.scoreToDistanceTranslation(score); } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + return methodResolver.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java index bd7598d84..f7d4342fc 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissFlatEncoder.java @@ -9,7 +9,10 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Set; @@ -41,4 +44,12 @@ public class FaissFlatEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext encoderContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + return CompressionLevel.x1; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index 41db777e3..c153a9328 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -52,12 +53,23 @@ public class FaissHNSWMethod extends AbstractFaissMethod { KNNConstants.ENCODER_FLAT, Collections.emptyMap() ); - private final static List SUPPORTED_ENCODERS = List.of( - new FaissFlatEncoder(), - new FaissSQEncoder(), - new FaissHNSWPQEncoder(), - new QFrameBitEncoder() + + // Package private so that the method resolving logic can access the methods + final static Encoder FLAT_ENCODER = new FaissFlatEncoder(); + final static Encoder SQ_ENCODER = new FaissSQEncoder(); + final static Encoder HNSW_PQ_ENCODER = new FaissHNSWPQEncoder(); + final static Encoder QFRAME_BIT_ENCODER = new QFrameBitEncoder(); + final static Map SUPPORTED_ENCODERS = Map.of( + FLAT_ENCODER.getName(), + FLAT_ENCODER, + SQ_ENCODER.getName(), + SQ_ENCODER, + HNSW_PQ_ENCODER.getName(), + HNSW_PQ_ENCODER, + QFRAME_BIT_ENCODER.getName(), + QFRAME_BIT_ENCODER ); + final static MethodComponent HNSW_COMPONENT = initMethodComponent(); /** * Constructor for FaissHNSWMethod @@ -65,7 +77,7 @@ public class FaissHNSWMethod extends AbstractFaissMethod { * @see AbstractKNNMethod */ public FaissHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); + super(HNSW_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); } private static MethodComponent initMethodComponent() { @@ -108,7 +120,7 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter() return new Parameter.MethodComponentContextParameter( METHOD_ENCODER_PARAMETER, DEFAULT_ENCODER_CONTEXT, - SUPPORTED_ENCODERS.stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) + SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) ); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java index 9bebf5b4d..6750d84ed 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoder.java @@ -9,8 +9,11 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Objects; import java.util.Set; @@ -69,4 +72,13 @@ public class FaissHNSWPQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + // TODO: For now, not supported out of the box + return CompressionLevel.NOT_CONFIGURED; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 70ab4222b..340c1f4d8 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -19,6 +19,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -55,20 +56,32 @@ public class FaissIVFMethod extends AbstractFaissMethod { KNNConstants.ENCODER_FLAT, Collections.emptyMap() ); - private final static List SUPPORTED_ENCODERS = List.of( - new FaissFlatEncoder(), - new FaissSQEncoder(), - new FaissIVFPQEncoder(), - new QFrameBitEncoder() + + // Package private so that the method resolving logic can access the methods + final static Encoder FLAT_ENCODER = new FaissFlatEncoder(); + final static Encoder SQ_ENCODER = new FaissSQEncoder(); + final static Encoder IVF_PQ_ENCODER = new FaissIVFPQEncoder(); + final static Encoder QFRAME_BIT_ENCODER = new QFrameBitEncoder(); + final static Map SUPPORTED_ENCODERS = Map.of( + FLAT_ENCODER.getName(), + FLAT_ENCODER, + SQ_ENCODER.getName(), + SQ_ENCODER, + IVF_PQ_ENCODER.getName(), + IVF_PQ_ENCODER, + QFRAME_BIT_ENCODER.getName(), + QFRAME_BIT_ENCODER ); + final static MethodComponent IVF_COMPONENT = initMethodComponent(); + /** * Constructor for FaissIVFMethod * * @see AbstractKNNMethod */ public FaissIVFMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultIVFSearchContext()); + super(IVF_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new DefaultIVFSearchContext()); } private static MethodComponent initMethodComponent() { @@ -133,7 +146,7 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter() return new Parameter.MethodComponentContextParameter( METHOD_ENCODER_PARAMETER, DEFAULT_ENCODER_CONTEXT, - SUPPORTED_ENCODERS.stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) + SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) ); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java index bb6623600..8d54548bd 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoder.java @@ -9,8 +9,11 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Set; @@ -90,4 +93,13 @@ public class FaissIVFPQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + // TODO: For now, not supported out of the box + return CompressionLevel.NOT_CONFIGURED; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java new file mode 100644 index 000000000..90e938eb3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.AbstractMethodResolver; +import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.index.engine.faiss.FaissHNSWMethod.HNSW_COMPONENT; +import static org.opensearch.knn.index.engine.faiss.FaissIVFMethod.IVF_COMPONENT; + +public class FaissMethodResolver extends AbstractMethodResolver { + + private static final Set SUPPORTED_COMPRESSION_LEVELS = Set.of( + CompressionLevel.x1, + CompressionLevel.x2, + CompressionLevel.x8, + CompressionLevel.x16, + CompressionLevel.x32 + ); + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + // Initial validation to ensure that there are no contradictions in provided parameters + validateConfig(knnMethodConfigContext); + + KNNMethodContext resolvedKNNMethodContext = initResolvedKNNMethodContext( + knnMethodContext, + KNNEngine.FAISS, + spaceType, + shouldRequireTraining ? METHOD_IVF : METHOD_HNSW + ); + MethodComponent method = METHOD_HNSW.equals(resolvedKNNMethodContext.getMethodComponentContext().getName()) == false + ? IVF_COMPONENT + : HNSW_COMPONENT; + Map encoderMap = method == HNSW_COMPONENT ? FaissHNSWMethod.SUPPORTED_ENCODERS : FaissIVFMethod.SUPPORTED_ENCODERS; + + // Fill in parameters for the encoder and then the method. + resolveEncoder(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap); + resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, method); + + // From the resolved method context, get the compression level and validate it against the passed in + // configuration + CompressionLevel resolvedCompressionLevel = resolveCompressionLevelFromMethodContext( + resolvedKNNMethodContext, + knnMethodConfigContext, + encoderMap + ); + + // Validate that resolved compression doesnt have any conflicts + validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel); + return ResolvedMethodContext.builder() + .knnMethodContext(resolvedKNNMethodContext) + .compressionLevel(resolvedCompressionLevel) + .build(); + } + + private void resolveEncoder( + KNNMethodContext resolvedKNNMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + Map encoderMap + ) { + if (shouldEncoderBeResolved(resolvedKNNMethodContext, knnMethodConfigContext) == false) { + return; + } + + CompressionLevel resolvedCompressionLevel = getDefaultCompressionLevel(knnMethodConfigContext); + if (resolvedCompressionLevel == CompressionLevel.x1) { + return; + } + + MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_FLAT, new HashMap<>()); + Encoder encoder = encoderMap.get(ENCODER_FLAT); + if (CompressionLevel.x2 == resolvedCompressionLevel) { + encoderComponentContext = new MethodComponentContext(ENCODER_SQ, new HashMap<>()); + encoder = encoderMap.get(ENCODER_SQ); + encoderComponentContext.getParameters().put(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16); + } + + if (CompressionLevel.x8 == resolvedCompressionLevel) { + encoderComponentContext = new MethodComponentContext(QFrameBitEncoder.NAME, new HashMap<>()); + encoder = encoderMap.get(QFrameBitEncoder.NAME); + encoderComponentContext.getParameters().put(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x8.numBitsForFloat32()); + } + + if (CompressionLevel.x16 == resolvedCompressionLevel) { + encoderComponentContext = new MethodComponentContext(QFrameBitEncoder.NAME, new HashMap<>()); + encoder = encoderMap.get(QFrameBitEncoder.NAME); + encoderComponentContext.getParameters().put(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x16.numBitsForFloat32()); + } + + if (CompressionLevel.x32 == resolvedCompressionLevel) { + encoderComponentContext = new MethodComponentContext(QFrameBitEncoder.NAME, new HashMap<>()); + encoder = encoderMap.get(QFrameBitEncoder.NAME); + encoderComponentContext.getParameters().put(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x32.numBitsForFloat32()); + } + + Map resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded( + encoderComponentContext, + encoder.getMethodComponent(), + knnMethodConfigContext + ); + encoderComponentContext.getParameters().putAll(resolvedParams); + resolvedKNNMethodContext.getMethodComponentContext().getParameters().put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + } + + // Method validates for explicit contradictions in the config + private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) { + CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel(); + ValidationException validationException = validateCompressionSupported( + compressionLevel, + SUPPORTED_COMPRESSION_LEVELS, + KNNEngine.FAISS, + null + ); + validationException = validateCompressionNotx1WhenOnDisk(knnMethodConfigContext, validationException); + if (validationException != null) { + throw validationException; + } + } + + private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) { + if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) { + return knnMethodConfigContext.getCompressionLevel(); + } + if (knnMethodConfigContext.getMode() == Mode.ON_DISK) { + return CompressionLevel.x32; + } + return CompressionLevel.x1; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java index 6d57aef2f..cd7e1e5f3 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java @@ -8,8 +8,11 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.Objects; import java.util.Set; @@ -49,4 +52,13 @@ public class FaissSQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + // TODO: Hard code for now + return CompressionLevel.x2; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java index e135fa33f..2292dc3cc 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoder.java @@ -6,12 +6,16 @@ package org.opensearch.knn.index.engine.faiss; import com.google.common.collect.ImmutableSet; +import org.opensearch.common.ValidationException; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.KNNLibraryIndexingContextImpl; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.util.HashMap; @@ -75,4 +79,35 @@ public class QFrameBitEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + if (methodComponentContext.getParameters().containsKey(BITCOUNT_PARAM) == false) { + return CompressionLevel.NOT_CONFIGURED; + } + + // Map the number of bits passed in, back to the compression level + Object value = methodComponentContext.getParameters().get(BITCOUNT_PARAM); + ValidationException validationException = METHOD_COMPONENT.getParameters() + .get(BITCOUNT_PARAM) + .validate(value, knnMethodConfigContext); + if (validationException != null) { + throw validationException; + } + + Integer bitCount = (Integer) value; + if (bitCount == 1) { + return CompressionLevel.x32; + } + + if (bitCount == 2) { + return CompressionLevel.x16; + } + + // Validation will ensure that only 1 of the supported bit count will be selected. + return CompressionLevel.x8; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java b/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java index 986380897..db516d309 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/Lucene.java @@ -10,6 +10,10 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.JVMLibrary; import org.opensearch.knn.index.engine.KNNMethod; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodResolver; +import org.opensearch.knn.index.engine.ResolvedMethodContext; import java.util.List; import java.util.Map; @@ -37,6 +41,8 @@ public class Lucene extends JVMLibrary { public final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); + private final MethodResolver methodResolver; + /** * Constructor * @@ -47,6 +53,7 @@ public class Lucene extends JVMLibrary { Lucene(Map methods, String version, Map> distanceTransform) { super(methods, version); this.distanceTransform = distanceTransform; + this.methodResolver = new LuceneMethodResolver(); } @Override @@ -86,4 +93,14 @@ public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { public List mmapFileExtensions() { return List.of("vec", "vex"); } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + return methodResolver.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index 317f67c10..57cc016a6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -6,19 +6,17 @@ package org.opensearch.knn.index.engine.lucene; import com.google.common.collect.ImmutableSet; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.AbstractKNNMethod; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; import java.util.Arrays; -import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -41,11 +39,10 @@ public class LuceneHNSWMethod extends AbstractKNNMethod { SpaceType.INNER_PRODUCT ); - private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext( - KNNConstants.ENCODER_FLAT, - Collections.emptyMap() - ); - private final static List SUPPORTED_ENCODERS = List.of(new LuceneSQEncoder()); + final static Encoder SQ_ENCODER = new LuceneSQEncoder(); + final static Map SUPPORTED_ENCODERS = Map.of(SQ_ENCODER.getName(), SQ_ENCODER); + + final static MethodComponent HNSW_METHOD_COMPONENT = initMethodComponent(); /** * Constructor for LuceneHNSWMethod @@ -53,7 +50,7 @@ public class LuceneHNSWMethod extends AbstractKNNMethod { * @see AbstractKNNMethod */ public LuceneHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new LuceneHNSWSearchContext()); + super(HNSW_METHOD_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new LuceneHNSWSearchContext()); } private static MethodComponent initMethodComponent() { @@ -78,8 +75,8 @@ private static MethodComponent initMethodComponent() { private static Parameter.MethodComponentContextParameter initEncoderParameter() { return new Parameter.MethodComponentContextParameter( METHOD_ENCODER_PARAMETER, - DEFAULT_ENCODER_CONTEXT, - SUPPORTED_ENCODERS.stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) + null, + SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent)) ); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java new file mode 100644 index 000000000..6546d9f93 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.lucene; + +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.AbstractMethodResolver; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.index.engine.lucene.LuceneHNSWMethod.HNSW_METHOD_COMPONENT; +import static org.opensearch.knn.index.engine.lucene.LuceneHNSWMethod.SQ_ENCODER; + +public class LuceneMethodResolver extends AbstractMethodResolver { + + private static final Set SUPPORTED_COMPRESSION_LEVELS = Set.of(CompressionLevel.x1, CompressionLevel.x4); + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + validateConfig(knnMethodConfigContext, shouldRequireTraining); + KNNMethodContext resolvedKNNMethodContext = initResolvedKNNMethodContext( + knnMethodContext, + KNNEngine.LUCENE, + spaceType, + METHOD_HNSW + ); + resolveEncoder(resolvedKNNMethodContext, knnMethodConfigContext); + resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, HNSW_METHOD_COMPONENT); + CompressionLevel resolvedCompressionLevel = resolveCompressionLevelFromMethodContext( + resolvedKNNMethodContext, + knnMethodConfigContext, + LuceneHNSWMethod.SUPPORTED_ENCODERS + ); + validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel); + return ResolvedMethodContext.builder() + .knnMethodContext(resolvedKNNMethodContext) + .compressionLevel(resolvedCompressionLevel) + .build(); + } + + protected void resolveEncoder(KNNMethodContext resolvedKNNMethodContext, KNNMethodConfigContext knnMethodConfigContext) { + if (shouldEncoderBeResolved(resolvedKNNMethodContext, knnMethodConfigContext) == false) { + return; + } + + CompressionLevel resolvedCompressionLevel = getDefaultCompressionLevel(knnMethodConfigContext); + if (resolvedCompressionLevel == CompressionLevel.x1) { + return; + } + + MethodComponentContext methodComponentContext = resolvedKNNMethodContext.getMethodComponentContext(); + MethodComponentContext encoderComponentContext = new MethodComponentContext(SQ_ENCODER.getName(), new HashMap<>()); + Map resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded( + encoderComponentContext, + SQ_ENCODER.getMethodComponent(), + knnMethodConfigContext + ); + encoderComponentContext.getParameters().putAll(resolvedParams); + methodComponentContext.getParameters().put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + } + + // Method validates for explicit contradictions in the config + private void validateConfig(KNNMethodConfigContext knnMethodConfigContext, boolean shouldRequireTraining) { + ValidationException validationException = validateNotTrainingContext(shouldRequireTraining, KNNEngine.LUCENE, null); + validationException = validateCompressionSupported( + knnMethodConfigContext.getCompressionLevel(), + SUPPORTED_COMPRESSION_LEVELS, + KNNEngine.LUCENE, + validationException + ); + validationException = validateCompressionNotx1WhenOnDisk(knnMethodConfigContext, validationException); + if (validationException != null) { + throw validationException; + } + } + + private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) { + if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) { + return knnMethodConfigContext.getCompressionLevel(); + } + if (knnMethodConfigContext.getMode() == Mode.ON_DISK) { + return CompressionLevel.x4; + } + return CompressionLevel.x1; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java index 0ec43db41..6bd16ebee 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java @@ -8,8 +8,11 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.MethodComponent; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; +import org.opensearch.knn.index.mapper.CompressionLevel; import java.util.List; import java.util.Set; @@ -49,4 +52,13 @@ public class LuceneSQEncoder implements Encoder { public MethodComponent getMethodComponent() { return METHOD_COMPONENT; } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext methodComponentContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + // Hard coding to 4x for now, given thats all that is supported. + return CompressionLevel.x4; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java index d35cc5f6c..4d7f7f423 100644 --- a/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/Nmslib.java @@ -8,7 +8,11 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNMethod; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodResolver; import org.opensearch.knn.index.engine.NativeLibrary; +import org.opensearch.knn.index.engine.ResolvedMethodContext; import java.util.Collections; import java.util.Map; @@ -27,6 +31,7 @@ public class Nmslib extends NativeLibrary { final static Map METHODS = ImmutableMap.of(METHOD_HNSW, new NmslibHNSWMethod()); public final static Nmslib INSTANCE = new Nmslib(METHODS, Collections.emptyMap(), CURRENT_VERSION, EXTENSION); + private final MethodResolver methodResolver; /** * Constructor for Nmslib @@ -43,6 +48,7 @@ private Nmslib( String extension ) { super(methods, scoreTranslation, currentVersion, extension); + this.methodResolver = new NmslibMethodResolver(); } @Override @@ -53,4 +59,14 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { return score; } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + return methodResolver.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java index 779c16cd3..d2440926e 100644 --- a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java @@ -38,12 +38,14 @@ public class NmslibHNSWMethod extends AbstractKNNMethod { SpaceType.INNER_PRODUCT ); + final static MethodComponent HNSW_METHOD_COMPONENT = initMethodComponent(); + /** * Constructor. Builds the method with the default parameters and supported spaces. * @see AbstractKNNMethod */ public NmslibHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); + super(HNSW_METHOD_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); } private static MethodComponent initMethodComponent() { diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolver.java new file mode 100644 index 000000000..619a00eda --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolver.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.nmslib; + +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.AbstractMethodResolver; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.index.engine.nmslib.NmslibHNSWMethod.HNSW_METHOD_COMPONENT; + +/** + * Method resolution logic for nmslib. Because nmslib does not support quantization, it is in general a validation + * before returning the original request + */ +public class NmslibMethodResolver extends AbstractMethodResolver { + + private static final Set SUPPORTED_COMPRESSION_LEVELS = Set.of(CompressionLevel.x1); + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + final SpaceType spaceType + ) { + validateConfig(knnMethodConfigContext, shouldRequireTraining); + KNNMethodContext resolvedKNNMethodContext = initResolvedKNNMethodContext( + knnMethodContext, + KNNEngine.NMSLIB, + spaceType, + METHOD_HNSW + ); + resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, HNSW_METHOD_COMPONENT); + return ResolvedMethodContext.builder().knnMethodContext(resolvedKNNMethodContext).compressionLevel(CompressionLevel.x1).build(); + } + + // Method validates for explicit contradictions in the config + private void validateConfig(KNNMethodConfigContext knnMethodConfigContext, boolean shouldRequireTraining) { + ValidationException validationException = validateNotTrainingContext(shouldRequireTraining, KNNEngine.NMSLIB, null); + CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel(); + validationException = validateCompressionSupported( + compressionLevel, + SUPPORTED_COMPRESSION_LEVELS, + KNNEngine.NMSLIB, + validationException + ); + + if (Mode.ON_DISK == knnMethodConfigContext.getMode()) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationError("Nmslib engine does not support disk-based search"); + } + + if (validationException != null) { + throw validationException; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index 4b5026598..cc80bb1ed 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -10,7 +10,9 @@ import org.opensearch.core.common.Strings; import org.opensearch.knn.index.query.rescore.RescoreContext; +import java.util.Collections; import java.util.Locale; +import java.util.Set; /** * Enum representing the compression level for float vectors. Compression in this sense refers to compressing a @@ -19,13 +21,13 @@ */ @AllArgsConstructor public enum CompressionLevel { - NOT_CONFIGURED(-1, "", null), - x1(1, "1x", null), - x2(2, "2x", null), - x4(4, "4x", new RescoreContext(1.0f)), - x8(8, "8x", new RescoreContext(1.5f)), - x16(16, "16x", new RescoreContext(2.0f)), - x32(32, "32x", new RescoreContext(2.0f)); + NOT_CONFIGURED(-1, "", null, Collections.emptySet()), + x1(1, "1x", null, Collections.emptySet()), + x2(2, "2x", null, Collections.emptySet()), + x4(4, "4x", null, Collections.emptySet()), + x8(8, "8x", new RescoreContext(1.5f), Set.of(Mode.ON_DISK)), + x16(16, "16x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)), + x32(32, "32x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)); // Internally, an empty string is easier to deal with them null. However, from the mapping, // we do not want users to pass in the empty string and instead want null. So we make the conversion herex @@ -33,6 +35,7 @@ public enum CompressionLevel { NOT_CONFIGURED.getName(), x1.getName(), x2.getName(), + x4.getName(), x8.getName(), x16.getName(), x32.getName() }; @@ -64,8 +67,8 @@ public static CompressionLevel fromName(String name) { private final int compressionLevel; @Getter private final String name; - @Getter private final RescoreContext defaultRescoreContext; + private final Set modesForRescore; /** * Gets the number of bits used to represent a float in order to achieve this compression. For instance, for @@ -90,4 +93,11 @@ public int numBitsForFloat32() { public static boolean isConfigured(CompressionLevel compressionLevel) { return compressionLevel != null && compressionLevel != NOT_CONFIGURED; } + + public RescoreContext getDefaultRescoreContext(Mode mode) { + if (modesForRescore.contains(mode)) { + return defaultRescoreContext; + } + return null; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index f149fa1d2..6e5138a56 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -39,12 +39,15 @@ import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.EngineResolver; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.engine.SpaceTypeResolver; import org.opensearch.knn.indices.ModelDao; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; @@ -174,6 +177,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected ModelDao modelDao; protected Version indexCreatedVersion; @Setter + @Getter private KNNMethodConfigContext knnMethodConfigContext; @Setter @Getter @@ -365,9 +369,20 @@ public Mapper.Builder parse(String name, Map node, ParserCont } else if (builder.modelId.get() != null) { validateFromModel(builder); } else { - validateMode(builder); + // Validate that the mode and compression are not set if data type is not float, as they are not + // supported. + validateModeAndCompressionForDataType(builder); + // If the original knnMethodContext is not null, resolve its space type and engine from the rest of the + // configuration. This is consistent with the existing behavior for space type in 2.16 where we modify the + // parsed value + SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType( + builder.originalParameters.getKnnMethodContext(), + builder.vectorDataType.get(), + builder.topLevelSpaceType.get() + ); + setSpaceType(builder.originalParameters.getKnnMethodContext(), resolvedSpaceType); validateSpaceType(builder); - resolveKNNMethodComponents(builder, parserContext); + resolveKNNMethodComponents(builder, parserContext, resolvedSpaceType); validateFromKNNMethod(builder); } @@ -391,20 +406,9 @@ private void validateSpaceType(KNNVectorFieldMapper.Builder builder) { } } - private void validateMode(KNNVectorFieldMapper.Builder builder) { - boolean isKNNMethodContextConfigured = builder.originalParameters.getKnnMethodContext() != null; - boolean isModeConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured(); - if (isModeConfigured && isKNNMethodContextConfigured) { - throw new MapperParsingException( - String.format( - Locale.ROOT, - "Compression and mode can not be specified in a \"method\" mapping configuration for field: %s", - builder.name - ) - ); - } - - if (isModeConfigured && builder.vectorDataType.getValue() != VectorDataType.FLOAT) { + private void validateModeAndCompressionForDataType(KNNVectorFieldMapper.Builder builder) { + boolean isModeOrCompressionConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured(); + if (isModeOrCompressionConfigured && builder.vectorDataType.getValue() != VectorDataType.FLOAT) { throw new MapperParsingException( String.format(Locale.ROOT, "Compression and mode cannot be used for non-float32 data type for field %s", builder.name) ); @@ -468,7 +472,12 @@ private void validateCompressionAndModeNotSet(KNNVectorFieldMapper.Builder build } } - private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, ParserContext parserContext) { + private void resolveKNNMethodComponents( + KNNVectorFieldMapper.Builder builder, + ParserContext parserContext, + SpaceType resolvedSpaceType + ) { + // Setup the initial configuration that is used to help resolve parameters. builder.setKnnMethodConfigContext( KNNMethodConfigContext.builder() .vectorDataType(builder.originalParameters.getVectorDataType()) @@ -479,36 +488,34 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa .build() ); - // Configure method from map or legacy + // If the original parameters are from legacy if (builder.originalParameters.isLegacyMapping()) { + // Then create KNNMethodContext to be used from the legacy index settings builder.originalParameters.setResolvedKnnMethodContext( - createKNNMethodContextFromLegacy( - parserContext.getSettings(), - parserContext.indexVersionCreated(), - SpaceType.getSpace(builder.topLevelSpaceType.get()) - ) + createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated(), resolvedSpaceType) ); - } else if (Mode.isConfigured(Mode.fromName(builder.mode.get())) - || CompressionLevel.isConfigured(CompressionLevel.fromName(builder.compressionLevel.get()))) { - // we need don't need to resolve the space type, whatever default we are using will be passed down to - // while resolving KNNMethodContext for the mode and compression. and then when we resolve the spaceType - // we will set the correct spaceType. - builder.originalParameters.setResolvedKnnMethodContext( - ModeBasedResolver.INSTANCE.resolveKNNMethodContext( - builder.knnMethodConfigContext.getMode(), - builder.knnMethodConfigContext.getCompressionLevel(), - false, - SpaceType.getSpace(builder.originalParameters.getTopLevelSpaceType()) - ) - ); - } - // this function should now correct the space type for the above resolved context too, if spaceType was - // not provided. - setSpaceType( + } + + // Based on config context, if the user does not set the engine, set it + KNNEngine resolvedKNNEngine = EngineResolver.INSTANCE.resolveEngine( + builder.knnMethodConfigContext, builder.originalParameters.getResolvedKnnMethodContext(), - builder.originalParameters.getVectorDataType(), - builder.topLevelSpaceType.get() + false ); + setEngine(builder.originalParameters.getResolvedKnnMethodContext(), resolvedKNNEngine); + + // Create a copy of the KNNMethodContext and fill in the parameters left blank by configuration context context + ResolvedMethodContext resolvedMethodContext = resolvedKNNEngine.resolveMethod( + builder.originalParameters.getResolvedKnnMethodContext(), + builder.knnMethodConfigContext, + false, + resolvedSpaceType + ); + + // The original parameters stores both the resolveMethodContext as well as the original provided by the + // user. Now that we have resolved, we need to update this in the original parameters. + builder.originalParameters.setResolvedKnnMethodContext(resolvedMethodContext.getKnnMethodContext()); + builder.knnMethodConfigContext.setCompressionLevel(resolvedMethodContext.getCompressionLevel()); } private boolean isKNNDisabled(Settings settings) { @@ -516,32 +523,18 @@ private boolean isKNNDisabled(Settings settings) { return !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(settings); } - private void setSpaceType( - final KNNMethodContext knnMethodContext, - final VectorDataType vectorDataType, - final String topLevelSpaceType - ) { - // Now KNNMethodContext should never be null. Because only case it could be null is flatMapper which is - // already handled + private void setSpaceType(final KNNMethodContext knnMethodContext, final SpaceType spaceType) { if (knnMethodContext == null) { - throw new IllegalArgumentException("KNNMethodContext cannot be null"); + return; } - final SpaceType topLevelSpaceTypeEnum = SpaceType.getSpace(topLevelSpaceType); - // Now set the spaceSpaceType for KNNMethodContext - if (SpaceType.UNDEFINED == knnMethodContext.getSpaceType()) { - // We are handling the case when top level space type is defined but method level spaceType is not - // defined. - if (topLevelSpaceTypeEnum != SpaceType.UNDEFINED) { - knnMethodContext.setSpaceType(topLevelSpaceTypeEnum); - } else { - // If both spaceTypes are undefined then put the default spaceType based on datatype - if (VectorDataType.BINARY == vectorDataType) { - knnMethodContext.setSpaceType(SpaceType.DEFAULT_BINARY); - } else { - knnMethodContext.setSpaceType(SpaceType.DEFAULT); - } - } + knnMethodContext.setSpaceType(spaceType); + } + + private void setEngine(final KNNMethodContext knnMethodContext, KNNEngine knnEngine) { + if (knnMethodContext == null || knnMethodContext.isEngineConfigured()) { + return; } + knnMethodContext.setKnnEngine(knnEngine); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 963688d0c..e684ba4f1 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -93,9 +93,6 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) if (userProvidedContext != null) { return userProvidedContext; } - return ModeBasedResolver.INSTANCE.resolveRescoreContext( - getKnnMappingConfig().getMode(), - getKnnMappingConfig().getCompressionLevel() - ); + return getKnnMappingConfig().getCompressionLevel().getDefaultRescoreContext(getKnnMappingConfig().getMode()); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 3da2745ac..fcf3aa034 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -63,6 +63,16 @@ public Optional getKnnMethodContext() { public int getDimension() { return knnMethodConfigContext.getDimension(); } + + @Override + public Mode getMode() { + return knnMethodConfigContext.getMode(); + } + + @Override + public CompressionLevel getCompressionLevel() { + return knnMethodConfigContext.getCompressionLevel(); + } } ); @@ -87,24 +97,23 @@ private LuceneFieldMapper( originalMappingParameters ); KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() - .orElseThrow(() -> new IllegalArgumentException("KNN method context is missing")); + KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); VectorDataType vectorDataType = mappedFieldType.getVectorDataType(); - final VectorSimilarityFunction vectorSimilarityFunction = knnMethodContext.getSpaceType() + final VectorSimilarityFunction vectorSimilarityFunction = resolvedKnnMethodContext.getSpaceType() .getKnnVectorSimilarityFunction() .getVectorSimilarityFunction(); this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(knnMethodContext.getKnnEngine()); + this.vectorFieldType = buildDocValuesFieldType(resolvedKnnMethodContext.getKnnEngine()); } else { this.vectorFieldType = null; } - KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); + KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 3d11949fe..bf5bc2b51 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -74,7 +74,7 @@ public int getDimension() { @Override public Mode getMode() { - return knnMethodConfigContext.getMode(); + return Mode.fromName(originalMappingParameters.getMode()); } @Override @@ -125,19 +125,18 @@ private MethodFieldMapper( originalMappingParameters ); this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); - KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() - .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); - KNNEngine knnEngine = knnMethodContext.getKnnEngine(); + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); + KNNEngine knnEngine = resolvedKnnMethodContext.getKnnEngine(); KNNLibraryIndexingContext knnLibraryIndexingContext = knnEngine.getKNNLibraryIndexingContext( - knnMethodContext, + resolvedKnnMethodContext, knnMethodConfigContext ); QuantizationConfig quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig(); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - this.fieldType.putAttribute(DIMENSION, String.valueOf(annConfig.getDimension())); - this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + this.fieldType.putAttribute(DIMENSION, String.valueOf(knnMappingConfig.getDimension())); + this.fieldType.putAttribute(SPACE_TYPE, resolvedKnnMethodContext.getSpaceType().getValue()); // Conditionally add quantization config if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) { this.fieldType.putAttribute(QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); @@ -157,8 +156,8 @@ private MethodFieldMapper( if (useLuceneBasedVectorField) { int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY - ? annConfig.getDimension() / 8 - : annConfig.getDimension(); + ? knnMappingConfig.getDimension() / 8 + : knnMappingConfig.getDimension(); final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT ? VectorEncoding.FLOAT32 : VectorEncoding.BYTE; diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModeBasedResolver.java b/src/main/java/org/opensearch/knn/index/mapper/ModeBasedResolver.java deleted file mode 100644 index 2a0c8ef46..000000000 --- a/src/main/java/org/opensearch/knn/index/mapper/ModeBasedResolver.java +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.MethodComponentContext; -import org.opensearch.knn.index.engine.faiss.QFrameBitEncoder; -import org.opensearch.knn.index.query.rescore.RescoreContext; - -import java.util.Locale; -import java.util.Map; -import java.util.Set; - -import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; -import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT; - -/** - * Class contains the logic to make parameter resolutions based on the {@link Mode} and {@link CompressionLevel}. - */ -public final class ModeBasedResolver { - - public static final ModeBasedResolver INSTANCE = new ModeBasedResolver(); - - private static final CompressionLevel DEFAULT_COMPRESSION_FOR_MODE_ON_DISK = CompressionLevel.x32; - private static final CompressionLevel DEFAULT_COMPRESSION_FOR_MODE_IN_MEMORY = CompressionLevel.x1; - public final static Set SUPPORTED_COMPRESSION_LEVELS = Set.of( - CompressionLevel.x1, - CompressionLevel.x2, - CompressionLevel.x8, - CompressionLevel.x16, - CompressionLevel.x32 - ); - - private ModeBasedResolver() {} - - /** - * Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNMethodContext} - * - * @param mode {@link Mode} - * @param compressionLevel {@link CompressionLevel} - * @param requiresTraining whether config requires trianing - * @return {@link KNNMethodContext} - */ - public KNNMethodContext resolveKNNMethodContext( - Mode mode, - CompressionLevel compressionLevel, - boolean requiresTraining, - SpaceType spaceType - ) { - if (requiresTraining) { - return resolveWithTraining(mode, compressionLevel, spaceType); - } - return resolveWithoutTraining(mode, compressionLevel, spaceType); - } - - private KNNMethodContext resolveWithoutTraining(Mode mode, CompressionLevel compressionLevel, final SpaceType spaceType) { - CompressionLevel resolvedCompressionLevel = resolveCompressionLevel(mode, compressionLevel); - MethodComponentContext encoderContext = resolveEncoder(resolvedCompressionLevel); - - KNNEngine knnEngine = Mode.ON_DISK == mode || encoderContext != null ? KNNEngine.FAISS : KNNEngine.DEFAULT; - - if (encoderContext != null) { - return new KNNMethodContext( - knnEngine, - spaceType, - new MethodComponentContext( - METHOD_HNSW, - Map.of( - METHOD_PARAMETER_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, - METHOD_PARAMETER_EF_CONSTRUCTION, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - METHOD_PARAMETER_EF_SEARCH, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, - METHOD_ENCODER_PARAMETER, - encoderContext - ) - ) - ); - } - - if (knnEngine == KNNEngine.FAISS) { - return new KNNMethodContext( - knnEngine, - spaceType, - new MethodComponentContext( - METHOD_HNSW, - Map.of( - METHOD_PARAMETER_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, - METHOD_PARAMETER_EF_CONSTRUCTION, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - METHOD_PARAMETER_EF_SEARCH, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH - ) - ) - ); - } - - return new KNNMethodContext( - knnEngine, - spaceType, - new MethodComponentContext( - METHOD_HNSW, - Map.of( - METHOD_PARAMETER_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, - METHOD_PARAMETER_EF_CONSTRUCTION, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION - ) - ) - ); - } - - private KNNMethodContext resolveWithTraining(Mode mode, CompressionLevel compressionLevel, SpaceType spaceType) { - CompressionLevel resolvedCompressionLevel = resolveCompressionLevel(mode, compressionLevel); - MethodComponentContext encoderContext = resolveEncoder(resolvedCompressionLevel); - if (encoderContext != null) { - return new KNNMethodContext( - KNNEngine.FAISS, - spaceType, - new MethodComponentContext( - METHOD_IVF, - Map.of( - METHOD_PARAMETER_NLIST, - METHOD_PARAMETER_NLIST_DEFAULT, - METHOD_PARAMETER_NPROBES, - METHOD_PARAMETER_NPROBES_DEFAULT, - METHOD_ENCODER_PARAMETER, - encoderContext - ) - ) - ); - } - - return new KNNMethodContext( - KNNEngine.FAISS, - spaceType, - new MethodComponentContext( - METHOD_IVF, - Map.of(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NPROBES_DEFAULT) - ) - ); - } - - /** - * Resolves the rescore context give the {@link Mode} and {@link CompressionLevel} - * - * @param mode {@link Mode} - * @param compressionLevel {@link CompressionLevel} - * @return {@link RescoreContext} - */ - public RescoreContext resolveRescoreContext(Mode mode, CompressionLevel compressionLevel) { - CompressionLevel resolvedCompressionLevel = resolveCompressionLevel(mode, compressionLevel); - return resolvedCompressionLevel.getDefaultRescoreContext(); - } - - private CompressionLevel resolveCompressionLevel(Mode mode, CompressionLevel compressionLevel) { - if (CompressionLevel.isConfigured(compressionLevel)) { - return compressionLevel; - } - - if (mode == Mode.ON_DISK) { - return DEFAULT_COMPRESSION_FOR_MODE_ON_DISK; - } - - return DEFAULT_COMPRESSION_FOR_MODE_IN_MEMORY; - } - - private MethodComponentContext resolveEncoder(CompressionLevel compressionLevel) { - if (CompressionLevel.isConfigured(compressionLevel) == false) { - throw new IllegalStateException("Compression level needs to be configured"); - } - - if (SUPPORTED_COMPRESSION_LEVELS.contains(compressionLevel) == false) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Unsupported compression level: \"[%s]\"", compressionLevel.getName()) - ); - } - - if (compressionLevel == CompressionLevel.x1) { - return null; - } - - if (compressionLevel == CompressionLevel.x2) { - return new MethodComponentContext(ENCODER_SQ, Map.of(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_CLIP, true)); - } - - return new MethodComponentContext( - QFrameBitEncoder.NAME, - Map.of(QFrameBitEncoder.BITCOUNT_PARAM, compressionLevel.numBitsForFloat32()) - ); - } - -} diff --git a/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java b/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java index 77bf07a90..340c450ee 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java +++ b/src/main/java/org/opensearch/knn/index/mapper/OriginalMappingParameters.java @@ -37,6 +37,8 @@ public final class OriginalMappingParameters { // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). // So, what we do is pass in a "resolvedKNNMethodContext" to ensure we track this resolveKnnMethodContext. // A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 + // + // In almost all cases except when dealing with the mapping, the resolved context should be used @Setter private KNNMethodContext resolvedKnnMethodContext; private final String mode; diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 02aa1e954..02f57660b 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -34,6 +34,7 @@ import java.util.Set; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; +import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -173,7 +174,9 @@ public static ValidationException validateKnnField( if (parameters != null && parameters.containsKey(KNNConstants.METHOD_ENCODER_PARAMETER)) { MethodComponentContext encoder = (MethodComponentContext) parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER); - if (encoder != null && VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS.contains(trainRequestVectorDataType)) { + if (encoder != null + && VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS.contains(trainRequestVectorDataType) + && ENCODER_FLAT.equals(encoder.getName()) == false) { exception.addValidationError( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index 0f7e8523b..71f7201de 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -19,6 +19,7 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.SpaceTypeResolver; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.indices.ModelUtil; @@ -132,8 +133,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } } - // Check that these parameters get set - ensureAtleasOneSet(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode, COMPRESSION_LEVEL_PARAMETER, compressionLevel); + ensureAtleastOneSet(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode, COMPRESSION_LEVEL_PARAMETER, compressionLevel); ensureMutualExclusion(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode); ensureMutualExclusion(KNN_METHOD, knnMethodContext, COMPRESSION_LEVEL_PARAMETER, compressionLevel); @@ -160,11 +160,12 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr vectorDataType, VectorDataType.FLOAT.getValue() ); - resolveSpaceTypeAndSetInKNNMethodContext(topLevelSpaceType, knnMethodContext); - // if KNNMethodContext was not null then spaceTypes we should fix the space type if it is not set. - if (knnMethodContext == null && topLevelSpaceType == SpaceType.UNDEFINED) { - topLevelSpaceType = SpaceType.DEFAULT; - } + SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType( + knnMethodContext, + vectorDataType, + topLevelSpaceType.getValue() + ); + setSpaceType(knnMethodContext, resolvedSpaceType); TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -176,7 +177,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr vectorDataType, Mode.fromName(mode), CompressionLevel.fromName(compressionLevel), - topLevelSpaceType + resolvedSpaceType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { @@ -217,26 +218,11 @@ private boolean ensureSpaceTypeNotSet(SpaceType spaceType) { return true; } - private void resolveSpaceTypeAndSetInKNNMethodContext(SpaceType topLevelSpaceType, KNNMethodContext knnMethodContext) { - // First check if KNNMethodContext is not null as it can be null - if (knnMethodContext != null) { - // if space type is not provided by user then it will undefined - if (knnMethodContext.getSpaceType() == SpaceType.UNDEFINED) { - // fix the top level spaceType if it is undefined - if (topLevelSpaceType == SpaceType.UNDEFINED) { - topLevelSpaceType = SpaceType.DEFAULT; - } - // set the space type now in KNNMethodContext - knnMethodContext.setSpaceType(topLevelSpaceType); - } else { - // if spaceType is set at 2 places lets ensure that we validate those cases and throw error - if (topLevelSpaceType != SpaceType.UNDEFINED) { - throw new IllegalArgumentException( - "Top Level spaceType and space type in method both are set. Set space type at 1 place." - ); - } - } + private void setSpaceType(KNNMethodContext knnMethodContext, SpaceType resolvedSpaceType) { + if (knnMethodContext == null) { + return; } + knnMethodContext.setSpaceType(resolvedSpaceType); } private void ensureIfSetThenEquals( @@ -263,8 +249,11 @@ private void ensureIfSetThenEquals( } } - private void ensureAtleasOneSet(String fieldNameA, Object valueA, String fieldNameB, Object valueB, String fieldNameC, Object valueC) { + private void ensureAtleastOneSet(String fieldNameA, Object valueA, String fieldNameB, Object valueB, String fieldNameC, Object valueC) { if (valueA == DEFAULT_NOT_SET_OBJECT_VALUE && valueB == DEFAULT_NOT_SET_OBJECT_VALUE && valueC == DEFAULT_NOT_SET_OBJECT_VALUE) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "At least \"[%s]\", \"[%s]\" or \"[%s]\" needs to be set", fieldNameA, fieldNameB, fieldNameC) + ); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index df24baf0e..9906ab490 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -22,10 +22,12 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.ResolvedMethodContext; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; -import org.opensearch.knn.index.mapper.ModeBasedResolver; +import org.opensearch.knn.index.engine.EngineResolver; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; @@ -116,6 +118,7 @@ public TrainingModelRequest( this.preferredNodeId = preferredNodeId; this.description = description; this.vectorDataType = vectorDataType; + this.mode = mode; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -123,10 +126,6 @@ public TrainingModelRequest( // Training data size in kilobytes. By default, this is invalid (it cant have negative kb). It eventually gets // calculated in transit. A user cannot set this value directly. - this.trainingDataSizeInKB = -1; - this.mode = mode; - this.compressionLevel = compressionLevel; - this.knnMethodConfigContext = KNNMethodConfigContext.builder() .vectorDataType(vectorDataType) .dimension(dimension) @@ -135,11 +134,11 @@ public TrainingModelRequest( .mode(mode) .build(); - if (knnMethodContext == null && (Mode.isConfigured(mode) || CompressionLevel.isConfigured(compressionLevel))) { - this.knnMethodContext = ModeBasedResolver.INSTANCE.resolveKNNMethodContext(mode, compressionLevel, true, spaceType); - } else { - this.knnMethodContext = knnMethodContext; - } + KNNEngine knnEngine = EngineResolver.INSTANCE.resolveEngine(knnMethodConfigContext, knnMethodContext, true); + ResolvedMethodContext resolvedMethodContext = knnEngine.resolveMethod(knnMethodContext, knnMethodConfigContext, true, spaceType); + this.knnMethodContext = resolvedMethodContext.getKnnMethodContext(); + this.compressionLevel = resolvedMethodContext.getCompressionLevel(); + this.knnMethodConfigContext.setCompressionLevel(resolvedMethodContext.getCompressionLevel()); } /** diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 6ef7373d2..21b3298be 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -35,6 +35,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; /** * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality. @@ -106,6 +107,11 @@ public static KNNMethodContext getDefaultKNNMethodContext() { return new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); } + public static KNNMethodContext getDefaultKNNMethodContextForModel() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_IVF, Collections.emptyMap()); + return new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext); + } + public static KNNMethodContext getDefaultByteKNNMethodContext() { MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); return new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java index ccaeb19a5..95d4b68a5 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java @@ -168,5 +168,15 @@ public Boolean isInitialized() { public void setInitialized(Boolean isInitialized) { } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return null; + } } } diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java new file mode 100644 index 000000000..f21459246 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractMethodResolverTests.java @@ -0,0 +1,158 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +public class AbstractMethodResolverTests extends KNNTestCase { + + private final static String ENCODER_NAME = "test"; + private final static CompressionLevel DEFAULT_COMPRESSION = CompressionLevel.x8; + + private final static AbstractMethodResolver TEST_RESOLVER = new AbstractMethodResolver() { + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return null; + } + }; + + private final static Encoder TEST_ENCODER = new Encoder() { + + @Override + public MethodComponent getMethodComponent() { + return MethodComponent.Builder.builder(ENCODER_NAME).build(); + } + + @Override + public CompressionLevel calculateCompressionLevel( + MethodComponentContext encoderContext, + KNNMethodConfigContext knnMethodConfigContext + ) { + return DEFAULT_COMPRESSION; + } + }; + + private final static Map ENCODER_MAP = Map.of(ENCODER_NAME, TEST_ENCODER); + + public void testResolveCompressionLevelFromMethodContext() { + assertEquals( + CompressionLevel.x1, + TEST_RESOLVER.resolveCompressionLevelFromMethodContext( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY), + KNNMethodConfigContext.builder().build(), + ENCODER_MAP + ) + ); + assertEquals( + DEFAULT_COMPRESSION, + TEST_RESOLVER.resolveCompressionLevelFromMethodContext( + new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_NAME, Map.of())) + ) + ), + KNNMethodConfigContext.builder().build(), + ENCODER_MAP + ) + ); + } + + public void testIsEncoderSpecified() { + assertFalse(TEST_RESOLVER.isEncoderSpecified(null)); + assertFalse( + TEST_RESOLVER.isEncoderSpecified(new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY)) + ); + assertFalse( + TEST_RESOLVER.isEncoderSpecified( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, Map.of())) + ) + ); + assertTrue( + TEST_RESOLVER.isEncoderSpecified( + new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, "test")) + ) + ) + ); + } + + public void testShouldEncoderBeResolved() { + assertFalse( + TEST_RESOLVER.shouldEncoderBeResolved( + new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, "test")) + ), + KNNMethodConfigContext.builder().build() + ) + ); + assertFalse( + TEST_RESOLVER.shouldEncoderBeResolved(null, KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build()) + ); + assertFalse( + TEST_RESOLVER.shouldEncoderBeResolved( + null, + KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).mode(Mode.ON_DISK).build() + ) + ); + assertFalse( + TEST_RESOLVER.shouldEncoderBeResolved( + null, + KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.NOT_CONFIGURED).mode(Mode.IN_MEMORY).build() + ) + ); + assertFalse( + TEST_RESOLVER.shouldEncoderBeResolved( + null, + KNNMethodConfigContext.builder() + .compressionLevel(CompressionLevel.NOT_CONFIGURED) + .mode(Mode.ON_DISK) + .vectorDataType(VectorDataType.BINARY) + .build() + ) + ); + assertTrue( + TEST_RESOLVER.shouldEncoderBeResolved( + null, + KNNMethodConfigContext.builder() + .compressionLevel(CompressionLevel.NOT_CONFIGURED) + .mode(Mode.ON_DISK) + .vectorDataType(VectorDataType.FLOAT) + .build() + ) + ); + assertTrue( + TEST_RESOLVER.shouldEncoderBeResolved( + null, + KNNMethodConfigContext.builder() + .compressionLevel(CompressionLevel.x32) + .mode(Mode.ON_DISK) + .vectorDataType(VectorDataType.FLOAT) + .build() + ) + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java new file mode 100644 index 000000000..df195883a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +public class EngineResolverTests extends KNNTestCase { + + private static final EngineResolver ENGINE_RESOLVER = EngineResolver.INSTANCE; + + public void testResolveEngine_whenEngineSpecifiedInMethod_thenThatEngine() { + assertEquals( + KNNEngine.LUCENE, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().build(), + new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, MethodComponentContext.EMPTY), + false + ) + ); + } + + public void testResolveEngine_whenRequiresTraining_thenFaiss() { + assertEquals(KNNEngine.FAISS, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, true)); + } + + public void testResolveEngine_whenModeAndCompressionAreFalse_thenDefault() { + assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false)); + assertEquals( + KNNEngine.DEFAULT, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().build(), + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false), + false + ) + ); + } + + public void testResolveEngine_whenModeSpecifiedAndCompressionIsNotSpecified_thenDefault() { + assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false)); + assertEquals( + KNNEngine.DEFAULT, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).build(), + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false), + false + ) + ); + } + + public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() { + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x1).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.DEFAULT, + ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build(), null, false) + ); + } + + public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() { + assertEquals( + KNNEngine.LUCENE, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x4).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.LUCENE, + ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).build(), null, false) + ); + } + + public void testResolveEngine_whenConfiguredForBQ_thenEngineIsFaiss() { + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x2).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x2).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x8).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x8).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x16).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x16).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x32).build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x32).build(), + null, + false + ) + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/engine/NativeLibraryTests.java index 243e9a3c1..c1fbe4aa5 100644 --- a/src/test/java/org/opensearch/knn/index/engine/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/NativeLibraryTests.java @@ -73,5 +73,15 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { return 0.0f; } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return null; + } } } diff --git a/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java new file mode 100644 index 000000000..99fc98c9e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/SpaceTypeResolverTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.SneakyThrows; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; + +public class SpaceTypeResolverTests extends KNNTestCase { + + private static final SpaceTypeResolver SPACE_TYPE_RESOLVER = SpaceTypeResolver.INSTANCE; + + public void testResolveSpaceType_whenNoConfigProvided_thenFallbackToVectorDataType() { + assertEquals(SpaceType.DEFAULT, SPACE_TYPE_RESOLVER.resolveSpaceType(null, VectorDataType.FLOAT, "")); + assertEquals(SpaceType.DEFAULT, SPACE_TYPE_RESOLVER.resolveSpaceType(null, VectorDataType.BYTE, "")); + assertEquals( + SpaceType.DEFAULT, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + "" + ) + ); + assertEquals(SpaceType.DEFAULT_BINARY, SPACE_TYPE_RESOLVER.resolveSpaceType(null, VectorDataType.BINARY, "")); + assertEquals( + SpaceType.DEFAULT_BINARY, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), + VectorDataType.BINARY, + "" + ) + ); + } + + @SneakyThrows + public void testResolveSpaceType_whenMethodSpaceTypeAndTopLevelSpecified_thenThrowIfConflict() { + expectThrows( + MapperParsingException.class, + () -> SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.L2, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + SpaceType.INNER_PRODUCT.getValue() + ) + ); + assertEquals( + SpaceType.DEFAULT, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + SpaceType.DEFAULT.getValue() + ) + ); + assertEquals( + SpaceType.DEFAULT, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + SpaceType.UNDEFINED.getValue() + ) + ); + assertEquals( + SpaceType.DEFAULT, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + SpaceType.DEFAULT.getValue() + ) + ); + assertEquals( + SpaceType.DEFAULT, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + SpaceType.UNDEFINED.getValue() + ) + ); + } + + @SneakyThrows + public void testResolveSpaceType_whenSpaceTypeSpecifiedOnce_thenReturnValue() { + assertEquals( + SpaceType.L1, + SPACE_TYPE_RESOLVER.resolveSpaceType( + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.L1, MethodComponentContext.EMPTY), + VectorDataType.FLOAT, + "" + ) + ); + assertEquals( + SpaceType.INNER_PRODUCT, + SPACE_TYPE_RESOLVER.resolveSpaceType(null, VectorDataType.FLOAT, SpaceType.INNER_PRODUCT.getValue()) + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java new file mode 100644 index 000000000..3f7dd9dcd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissHNSWPQEncoderTests.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.CompressionLevel; + +public class FaissHNSWPQEncoderTests extends KNNTestCase { + public void testCalculateCompressionLevel() { + FaissHNSWPQEncoder encoder = new FaissHNSWPQEncoder(); + assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(null, null)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java new file mode 100644 index 000000000..35b7a64ab --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissIVFPQEncoderTests.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.CompressionLevel; + +public class FaissIVFPQEncoderTests extends KNNTestCase { + public void testCalculateCompressionLevel() { + FaissIVFPQEncoder encoder = new FaissIVFPQEncoder(); + assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(null, null)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java new file mode 100644 index 000000000..ad466d4bb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.Version; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.MethodResolver; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +public class FaissMethodResolverTests extends KNNTestCase { + + MethodResolver TEST_RESOLVER = new FaissMethodResolver(); + + public void testResolveMethod_whenValid_thenResolve() { + ResolvedMethodContext resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x1, SpaceType.INNER_PRODUCT, ENCODER_FLAT, false); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x32, SpaceType.INNER_PRODUCT, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x16) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x16, SpaceType.INNER_PRODUCT, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .compressionLevel(CompressionLevel.x16) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x16, SpaceType.INNER_PRODUCT, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.L2, + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_ENCODER_PARAMETER, + new MethodComponentContext( + QFrameBitEncoder.NAME, + Map.of(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x8.numBitsForFloat32()) + ) + ) + ) + ), + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x8, SpaceType.L2, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.L2, + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_ENCODER_PARAMETER, + new MethodComponentContext( + QFrameBitEncoder.NAME, + Map.of(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x8.numBitsForFloat32()) + ) + ) + ) + ), + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.L2 + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x8, SpaceType.L2, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of())), + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.L2 + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x1, SpaceType.L2, ENCODER_FLAT, false); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of())), + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.BINARY).versionCreated(Version.CURRENT).build(), + false, + SpaceType.L2 + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x1, SpaceType.L2, ENCODER_FLAT, false); + } + + private void validateResolveMethodContext( + ResolvedMethodContext resolvedMethodContext, + CompressionLevel expectedCompression, + SpaceType expectedSpaceType, + String expectedEncoderName, + boolean checkBitsEncoderParam + ) { + assertEquals(expectedCompression, resolvedMethodContext.getCompressionLevel()); + assertEquals(KNNEngine.FAISS, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(expectedSpaceType, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals( + expectedEncoderName, + ((MethodComponentContext) resolvedMethodContext.getKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER)).getName() + ); + if (checkBitsEncoderParam) { + assertEquals( + expectedCompression.numBitsForFloat32(), + ((MethodComponentContext) resolvedMethodContext.getKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER)).getParameters().get(QFrameBitEncoder.BITCOUNT_PARAM) + ); + } + + } + + public void testResolveMethod_whenInvalid_thenThrow() { + // Invalid compression + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.BINARY) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + + // Invalid spec ondisk and compression is 1 + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x1) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + + // Invalid compression conflict + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.INNER_PRODUCT, + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_ENCODER_PARAMETER, + new MethodComponentContext( + QFrameBitEncoder.NAME, + Map.of(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x32.numBitsForFloat32()) + ) + ) + ) + ), + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x8) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ) + + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java new file mode 100644 index 000000000..3905158a2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.faiss; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.CompressionLevel; + +public class FaissSQEncoderTests extends KNNTestCase { + public void testCalculateCompressionLevel() { + FaissSQEncoder encoder = new FaissSQEncoder(); + assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(null, null)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoderTests.java index 7457b49aa..e926916af 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoderTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/QFrameBitEncoderTests.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.Version; +import org.opensearch.common.ValidationException; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; @@ -14,10 +15,16 @@ import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import java.util.HashMap; +import java.util.Map; + import static org.opensearch.knn.common.KNNConstants.FAISS_FLAT_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.engine.faiss.QFrameBitEncoder.BITCOUNT_PARAM; public class QFrameBitEncoderTests extends KNNTestCase { @@ -121,4 +128,40 @@ public void testEstimateOverheadInKB() { .estimateOverheadInKB(new MethodComponentContext(QFrameBitEncoder.NAME, ImmutableMap.of(BITCOUNT_PARAM, 4)), 8) ); } + + public void testCalculateCompressionLevel() { + QFrameBitEncoder encoder = new QFrameBitEncoder(); + assertEquals( + CompressionLevel.x32, + encoder.calculateCompressionLevel(generateMethodComponentContext(CompressionLevel.x32.numBitsForFloat32()), null) + ); + assertEquals( + CompressionLevel.x16, + encoder.calculateCompressionLevel(generateMethodComponentContext(CompressionLevel.x16.numBitsForFloat32()), null) + ); + assertEquals( + CompressionLevel.x8, + encoder.calculateCompressionLevel(generateMethodComponentContext(CompressionLevel.x8.numBitsForFloat32()), null) + ); + assertEquals( + CompressionLevel.NOT_CONFIGURED, + encoder.calculateCompressionLevel( + new MethodComponentContext( + METHOD_HNSW, + new HashMap<>(Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(QFrameBitEncoder.NAME, Map.of()))) + ), + null + ) + ); + + expectThrows( + ValidationException.class, + () -> encoder.calculateCompressionLevel(generateMethodComponentContext(CompressionLevel.x4.numBitsForFloat32()), null) + ); + expectThrows(ValidationException.class, () -> encoder.calculateCompressionLevel(generateMethodComponentContext(-1), null)); + } + + private MethodComponentContext generateMethodComponentContext(int bitCount) { + return new MethodComponentContext(QFrameBitEncoder.NAME, Map.of(BITCOUNT_PARAM, bitCount)); + } } diff --git a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolverTests.java new file mode 100644 index 000000000..833d83135 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolverTests.java @@ -0,0 +1,212 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.lucene; + +import org.opensearch.Version; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.MethodResolver; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +public class LuceneMethodResolverTests extends KNNTestCase { + MethodResolver TEST_RESOLVER = new LuceneMethodResolver(); + + public void testResolveMethod_whenValid_thenResolve() { + ResolvedMethodContext resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x1, resolvedMethodContext.getCompressionLevel()); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .mode(Mode.ON_DISK) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertTrue( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x4, resolvedMethodContext.getCompressionLevel()); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .compressionLevel(CompressionLevel.x4) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertTrue( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x4, resolvedMethodContext.getCompressionLevel()); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_HNSW, Map.of()) + ); + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + knnMethodContext, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .mode(Mode.ON_DISK) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertTrue( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x4, resolvedMethodContext.getCompressionLevel()); + assertNotEquals(knnMethodContext, resolvedMethodContext.getKnnMethodContext()); + + knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_HNSW, Map.of()) + ); + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + knnMethodContext, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .compressionLevel(CompressionLevel.x4) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertTrue( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x4, resolvedMethodContext.getCompressionLevel()); + assertNotEquals(knnMethodContext, resolvedMethodContext.getKnnMethodContext()); + + knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_SQ, Map.of()))) + ); + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + knnMethodContext, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertTrue( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x4, resolvedMethodContext.getCompressionLevel()); + assertNotEquals(knnMethodContext, resolvedMethodContext.getKnnMethodContext()); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.BYTE).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertFalse( + resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER) + ); + assertEquals(KNNEngine.LUCENE, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x1, resolvedMethodContext.getCompressionLevel()); + } + + public void testResolveMethod_whenInvalid_thenThrow() { + // Invalid training context + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + true, + SpaceType.L2 + ) + ); + + // Invalid compression + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .compressionLevel(CompressionLevel.x32) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + + // Invalid spec ondisk and compression is 1 + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x1) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java new file mode 100644 index 000000000..139f96e8b --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.lucene; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.CompressionLevel; + +public class LuceneSQEncoderTests extends KNNTestCase { + public void testCalculateCompressionLevel() { + LuceneSQEncoder encoder = new LuceneSQEncoder(); + assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(null, null)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolverTests.java new file mode 100644 index 000000000..065e0e378 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/engine/nmslib/NmslibMethodResolverTests.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.nmslib; + +import org.opensearch.Version; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.MethodResolver; +import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +public class NmslibMethodResolverTests extends KNNTestCase { + + MethodResolver TEST_RESOLVER = new NmslibMethodResolver(); + + public void testResolveMethod_whenValid_thenResolve() { + // No configuration passed in + ResolvedMethodContext resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertEquals(KNNEngine.NMSLIB, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x1, resolvedMethodContext.getCompressionLevel()); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.NMSLIB, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_HNSW, Map.of()) + ); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + knnMethodContext, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.INNER_PRODUCT + ); + assertEquals(METHOD_HNSW, resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getName()); + assertFalse(resolvedMethodContext.getKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + assertEquals(KNNEngine.NMSLIB, resolvedMethodContext.getKnnMethodContext().getKnnEngine()); + assertEquals(SpaceType.INNER_PRODUCT, resolvedMethodContext.getKnnMethodContext().getSpaceType()); + assertEquals(CompressionLevel.x1, resolvedMethodContext.getCompressionLevel()); + assertNotEquals(knnMethodContext, resolvedMethodContext.getKnnMethodContext()); + } + + public void testResolveMethod_whenInvalid_thenThrow() { + // Invalid training context + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + true, + SpaceType.L2 + ) + ); + + // Invalid compression + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .compressionLevel(CompressionLevel.x8) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + + // Invalid mode + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.L2 + ) + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 84cbf05dc..98bbf42ca 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -48,6 +48,7 @@ import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.engine.faiss.QFrameBitEncoder; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -93,6 +94,7 @@ @Log4j2 public class KNNVectorFieldMapperTests extends KNNTestCase { + private static final String TEST_INDEX_NAME = "test-index-name"; private static final String TEST_FIELD_NAME = "test-field-name"; private static final int TEST_DIMENSION = 17; @@ -1633,7 +1635,7 @@ public void testTypeParser_whenBinaryWithLegacyKNNEnabled_thenException() throws typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); }); - assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported with")); + assertTrue(ex.getMessage(), ex.getMessage().contains("does not support space type")); } public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { @@ -1653,6 +1655,345 @@ public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { } } + public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOException { + int dimension = 16; + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + // Default to nmslib and ensure legacy is in use + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .endObject(); + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + assertNull(builder.getOriginalParameters().getKnnMethodContext()); + assertTrue(builder.getOriginalParameters().isLegacyMapping()); + validateBuilderAfterParsing( + builder, + KNNEngine.NMSLIB, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x1, + CompressionLevel.NOT_CONFIGURED, + Mode.NOT_CONFIGURED, + false + ); + + // If mode is in memory and 1x compression, again use default legacy + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x1.getName()) + .field(MODE_PARAMETER, Mode.IN_MEMORY.getName()) + .endObject(); + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + assertNull(builder.getOriginalParameters().getKnnMethodContext()); + assertFalse(builder.getOriginalParameters().isLegacyMapping()); + validateBuilderAfterParsing( + builder, + KNNEngine.NMSLIB, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x1, + CompressionLevel.x1, + Mode.IN_MEMORY, + false + ); + + // Default on disk is faiss with 32x binary quant + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.FAISS, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x32, + CompressionLevel.NOT_CONFIGURED, + Mode.ON_DISK, + true + ); + + // Ensure 2x does not use binary quantization + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x2.getName()) + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.FAISS, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x2, + CompressionLevel.x2, + Mode.NOT_CONFIGURED, + false + ); + + // For 8x ensure that it does use binary quantization + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x8.getName()) + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.FAISS, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x8, + CompressionLevel.x8, + Mode.ON_DISK, + true + ); + + // For 4x compression on disk, use Lucene + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.LUCENE, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x4, + CompressionLevel.x4, + Mode.ON_DISK, + false + ); + + // For 4x compression in memory, use Lucene + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.IN_MEMORY.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.LUCENE, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x4, + CompressionLevel.x4, + Mode.IN_MEMORY, + false + ); + + // For override, ensure compression is correct + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.FAISS) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, QFrameBitEncoder.NAME) + .startObject(PARAMETERS) + .field(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x16.numBitsForFloat32()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(xContentBuilder), + buildParserContext(TEST_INDEX_NAME, settings) + ); + validateBuilderAfterParsing( + builder, + KNNEngine.FAISS, + SpaceType.L2, + VectorDataType.FLOAT, + CompressionLevel.x16, + CompressionLevel.NOT_CONFIGURED, + Mode.NOT_CONFIGURED, + true + ); + + // Override with conflicting compression levels should fail + XContentBuilder invalidXContentBuilder1 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.FAISS) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, QFrameBitEncoder.NAME) + .startObject(PARAMETERS) + .field(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x16.numBitsForFloat32()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + expectThrows( + ValidationException.class, + () -> typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(invalidXContentBuilder1), + buildParserContext(TEST_INDEX_NAME, settings) + ) + ); + + // Invalid if vector data type is binary + XContentBuilder invalidXContentBuilder2 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .field(MODE_PARAMETER, Mode.IN_MEMORY.getName()) + .endObject(); + + expectThrows( + MapperParsingException.class, + () -> typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(invalidXContentBuilder2), + buildParserContext(TEST_INDEX_NAME, settings) + ) + ); + + // Invalid if engine doesnt support the compression + XContentBuilder invalidXContentBuilder3 = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.FAISS) + .endObject() + .endObject(); + + expectThrows( + ValidationException.class, + () -> typeParser.parse( + TEST_FIELD_NAME, + xContentBuilderToMap(invalidXContentBuilder3), + buildParserContext(TEST_INDEX_NAME, settings) + ) + ); + } + + private void validateBuilderAfterParsing( + KNNVectorFieldMapper.Builder builder, + KNNEngine expectedEngine, + SpaceType expectedSpaceType, + VectorDataType expectedVectorDataType, + CompressionLevel expectedResolvedCompressionLevel, + CompressionLevel expectedOriginalCompressionLevel, + Mode expectedMode, + boolean shouldUsesBinaryQFramework + ) { + assertEquals(expectedEngine, builder.getOriginalParameters().getResolvedKnnMethodContext().getKnnEngine()); + assertEquals(expectedSpaceType, builder.getOriginalParameters().getResolvedKnnMethodContext().getSpaceType()); + assertEquals(expectedVectorDataType, builder.getKnnMethodConfigContext().getVectorDataType()); + + assertEquals(expectedResolvedCompressionLevel, builder.getKnnMethodConfigContext().getCompressionLevel()); + assertEquals(expectedOriginalCompressionLevel, CompressionLevel.fromName(builder.getOriginalParameters().getCompressionLevel())); + assertEquals(expectedMode, Mode.fromName(builder.getOriginalParameters().getMode())); + assertEquals(expectedMode, builder.getKnnMethodConfigContext().getMode()); + assertFalse(builder.getOriginalParameters().getResolvedKnnMethodContext().getMethodComponentContext().getParameters().isEmpty()); + + if (shouldUsesBinaryQFramework) { + assertEquals( + QFrameBitEncoder.NAME, + ((MethodComponentContext) builder.getOriginalParameters() + .getResolvedKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER)).getName() + ); + assertEquals( + expectedResolvedCompressionLevel.numBitsForFloat32(), + (int) ((MethodComponentContext) builder.getOriginalParameters() + .getResolvedKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER)).getParameters().get(QFrameBitEncoder.BITCOUNT_PARAM) + ); + } else { + assertTrue( + builder.getOriginalParameters().getResolvedKnnMethodContext().getMethodComponentContext().getParameters().isEmpty() + || builder.getOriginalParameters() + .getResolvedKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .containsKey(METHOD_ENCODER_PARAMETER) == false + || QFrameBitEncoder.NAME.equals( + ((MethodComponentContext) builder.getOriginalParameters() + .getResolvedKnnMethodContext() + .getMethodComponentContext() + .getParameters() + .get(METHOD_ENCODER_PARAMETER)).getName() + ) == false + ); + } + } + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder() { return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() .name(TEST_FIELD_NAME) diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 925ba0fff..ea9203196 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -21,7 +21,6 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; -import org.opensearch.knn.index.mapper.ModeBasedResolver; import org.opensearch.knn.index.query.parser.RescoreParser; import java.util.List; @@ -30,7 +29,6 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; @@ -39,7 +37,6 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.NAME; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -71,29 +68,16 @@ public class ModeAndCompressionIT extends KNNRestTestCase { 1.0f, 2.0f }; + private static final String[] COMPRESSION_LEVELS = new String[] { + CompressionLevel.x2.getName(), + CompressionLevel.x4.getName(), + CompressionLevel.x8.getName(), + CompressionLevel.x16.getName(), + CompressionLevel.x32.getName() }; + @SneakyThrows public void testIndexCreation_whenInvalid_thenFail() { XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(FIELD_NAME) - .field("type", "knn_vector") - .field("dimension", DIMENSION) - .field(MODE_PARAMETER, "on_disk") - .field(COMPRESSION_LEVEL_PARAMETER, "16x") - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .field(KNN_ENGINE, FAISS_NAME) - .startObject(PARAMETERS) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - String mapping1 = builder.toString(); - expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping1)); - - builder = XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(FIELD_NAME) @@ -139,7 +123,7 @@ public void testIndexCreation_whenInvalid_thenFail() { @SneakyThrows public void testIndexCreation_whenValid_ThenSucceed() { XContentBuilder builder; - for (CompressionLevel compressionLevel : ModeBasedResolver.SUPPORTED_COMPRESSION_LEVELS) { + for (String compressionLevel : COMPRESSION_LEVELS) { String indexName = INDEX_NAME + compressionLevel; builder = XContentFactory.jsonBuilder() .startObject() @@ -147,16 +131,23 @@ public void testIndexCreation_whenValid_ThenSucceed() { .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", DIMENSION) - .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) .endObject() .endObject() .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + logger.info("Compression level {}", compressionLevel); + validateSearch( + indexName, + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + compressionLevel, + Mode.NOT_CONFIGURED.getName() + ); } - for (CompressionLevel compressionLevel : ModeBasedResolver.SUPPORTED_COMPRESSION_LEVELS) { + for (String compressionLevel : COMPRESSION_LEVELS) { for (String mode : Mode.NAMES_ARRAY) { String indexName = INDEX_NAME + compressionLevel + "_" + mode; builder = XContentFactory.jsonBuilder() @@ -166,13 +157,20 @@ public void testIndexCreation_whenValid_ThenSucceed() { .field("type", "knn_vector") .field("dimension", DIMENSION) .field(MODE_PARAMETER, mode) - .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) .endObject() .endObject() .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + logger.info("Compression level {}", compressionLevel); + validateSearch( + indexName, + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + compressionLevel, + mode + ); } } @@ -190,7 +188,14 @@ public void testIndexCreation_whenValid_ThenSucceed() { .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + logger.info("Compression level {}", CompressionLevel.NOT_CONFIGURED.getName()); + validateSearch( + indexName, + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + CompressionLevel.NOT_CONFIGURED.getName(), + mode + ); } } @@ -252,7 +257,7 @@ public void testTraining_whenInvalid_thenFail() { public void testTraining_whenValid_thenSucceed() { setupTrainingIndex(); XContentBuilder builder; - for (CompressionLevel compressionLevel : ModeBasedResolver.SUPPORTED_COMPRESSION_LEVELS) { + for (String compressionLevel : CompressionLevel.NAMES_ARRAY) { String indexName = INDEX_NAME + compressionLevel; String modelId = indexName; builder = XContentFactory.jsonBuilder() @@ -261,7 +266,7 @@ public void testTraining_whenValid_thenSucceed() { .field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME) .field(KNNConstants.DIMENSION, DIMENSION) .field(MODEL_DESCRIPTION, "") - .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) .endObject(); validateTraining(modelId, builder); builder = XContentFactory.jsonBuilder() @@ -275,10 +280,16 @@ public void testTraining_whenValid_thenSucceed() { .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NLIST_DEFAULT); + validateSearch( + indexName, + METHOD_PARAMETER_NPROBES, + METHOD_PARAMETER_NLIST_DEFAULT, + compressionLevel, + Mode.NOT_CONFIGURED.getName() + ); } - for (CompressionLevel compressionLevel : ModeBasedResolver.SUPPORTED_COMPRESSION_LEVELS) { + for (String compressionLevel : CompressionLevel.NAMES_ARRAY) { for (String mode : Mode.NAMES_ARRAY) { String indexName = INDEX_NAME + compressionLevel + "_" + mode; String modelId = indexName; @@ -288,7 +299,7 @@ public void testTraining_whenValid_thenSucceed() { .field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME) .field(KNNConstants.DIMENSION, DIMENSION) .field(MODEL_DESCRIPTION, "") - .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) .field(MODE_PARAMETER, mode) .endObject(); validateTraining(modelId, builder); @@ -303,7 +314,7 @@ public void testTraining_whenValid_thenSucceed() { .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NLIST_DEFAULT); + validateSearch(indexName, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NLIST_DEFAULT, compressionLevel, mode); } } @@ -330,7 +341,13 @@ public void testTraining_whenValid_thenSucceed() { .endObject(); String mapping = builder.toString(); validateIndex(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_NPROBES, METHOD_PARAMETER_NLIST_DEFAULT); + validateSearch( + indexName, + METHOD_PARAMETER_NPROBES, + METHOD_PARAMETER_NLIST_DEFAULT, + CompressionLevel.NOT_CONFIGURED.getName(), + mode + ); } } @@ -381,7 +398,13 @@ private void validateTraining(String modelId, XContentBuilder builder) { } @SneakyThrows - private void validateSearch(String indexName, String methodParameterName, int methodParameterValue) { + private void validateSearch( + String indexName, + String methodParameterName, + int methodParameterValue, + String compressionLevelString, + String mode + ) { // Basic search Response response = searchKNNIndex( indexName, @@ -436,7 +459,9 @@ private void validateSearch(String indexName, String methodParameterName, int me String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); assertEquals(NUM_DOCS, exactSearchKnnResults.size()); - Assert.assertEquals(exactSearchKnnResults, knnResults); + if (CompressionLevel.x4.getName().equals(compressionLevelString) == false && Mode.ON_DISK.getName().equals(mode)) { + Assert.assertEquals(exactSearchKnnResults, knnResults); + } // Search with rescore response = searchKNNIndex( @@ -464,6 +489,8 @@ private void validateSearch(String indexName, String methodParameterName, int me responseBody = EntityUtils.toString(response.getEntity()); knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); assertEquals(K, knnResults.size()); - Assert.assertEquals(exactSearchKnnResults, knnResults); + if (CompressionLevel.x4.getName().equals(compressionLevelString) == false && Mode.ON_DISK.getName().equals(mode)) { + Assert.assertEquals(exactSearchKnnResults, knnResults); + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 4399b3318..5dbd0fc8b 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -18,6 +18,7 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNLibrary; +import org.opensearch.knn.index.engine.ResolvedMethodContext; import org.opensearch.test.OpenSearchTestCase; public class LibraryInitializedSupplierTests extends OpenSearchTestCase { @@ -105,5 +106,15 @@ public Boolean isInitialized() { public void setInitialized(Boolean isInitialized) { this.initialized = isInitialized; } + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return null; + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 151626ef5..30c5d33a1 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -304,7 +304,7 @@ public void testTrainingIndexSize() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - getDefaultKNNMethodContext(), + getDefaultKNNMethodContextForModel(), dimension, trainingIndexName, "training-field", @@ -353,7 +353,7 @@ public void testTrainIndexSize_whenDataTypeIsBinary() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - getDefaultKNNMethodContext(), + getDefaultKNNMethodContextForModel(), dimension, trainingIndexName, "training-field", @@ -403,7 +403,7 @@ public void testTrainIndexSize_whenDataTypeIsByte() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - getDefaultKNNMethodContext(), + getDefaultKNNMethodContextForModel(), dimension, trainingIndexName, "training-field", diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 79292fb53..2c423e0ef 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -44,7 +44,6 @@ import java.util.List; import java.util.Map; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -52,7 +51,7 @@ public class TrainingModelRequestTests extends KNNTestCase { public void testStreams() throws IOException { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContextForModel(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -142,7 +141,7 @@ public void testStreams() throws IOException { public void testGetters() { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContextForModel(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -170,7 +169,6 @@ public void testGetters() { trainingModelRequest.setTrainingDataSizeInKB(trainingSetSizeInKB); assertEquals(modelId, trainingModelRequest.getModelId()); - assertEquals(knnMethodContext, trainingModelRequest.getKnnMethodContext()); assertEquals(dimension, trainingModelRequest.getDimension()); assertEquals(trainingIndex, trainingModelRequest.getTrainingIndex()); assertEquals(trainingField, trainingModelRequest.getTrainingField()); @@ -187,18 +185,13 @@ public void testValidation_invalid_modelIdAlreadyExists() { // Setup the training request String modelId = "test-model-id"; - KNNEngine knnEngine = mock(KNNEngine.class); - when(knnEngine.validateMethod(any(), any())).thenReturn(null); - when(knnEngine.isTrainingRequired(any())).thenReturn(true); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, - knnMethodContext, + getDefaultKNNMethodContextForModel(), dimension, trainingIndex, trainingField, @@ -251,18 +244,13 @@ public void testValidation_blocked_modelId() { // Setup the training request String modelId = "test-model-id"; - KNNEngine knnEngine = mock(KNNEngine.class); - when(knnEngine.validateMethod(any(), any())).thenReturn(null); - when(knnEngine.isTrainingRequired(any())).thenReturn(true); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, - knnMethodContext, + getDefaultKNNMethodContextForModel(), dimension, trainingIndex, trainingField, @@ -298,48 +286,26 @@ public void testValidation_invalid_invalidMethodContext() { String modelId = "test-model-id"; // Mock throwing an exception on validation - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - String validationExceptionMessage = "knn method invalid"; - ValidationException validationException = new ValidationException(); - validationException.addValidationError(validationExceptionMessage); - when(knnMethodContext.validate(any())).thenReturn(validationException); - - when(knnMethodContext.isTrainingRequired()).thenReturn(false); - when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; - TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null, - VectorDataType.DEFAULT, - Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + ValidationException validationException = expectThrows( + ValidationException.class, + () -> new TrainingModelRequest( + modelId, + getDefaultKNNMethodContext(), + dimension, + trainingIndex, + trainingField, + null, + null, + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED + ) ); - - // Mock the model dao to return null so that no exception is produced - ModelDao modelDao = mock(ModelDao.class); - when(modelDao.getMetadata(modelId)).thenReturn(null); - - // This cluster service will result in no validation exceptions - ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); - - // Initialize static components with the mocks - TrainingModelRequest.initialize(modelDao, clusterService); - - // Test that validation produces model already exists error message - ActionRequestValidationException exception = trainingModelRequest.validate(); - assertNotNull(exception); - List validationErrors = exception.validationErrors(); - assertEquals(2, validationErrors.size()); - assertTrue(validationErrors.get(0).contains(validationExceptionMessage)); - assertTrue(validationErrors.get(1).contains("Method does not require training.")); + assertTrue(validationException.getMessage().contains("engine from training context")); } public void testValidation_invalid_trainingIndexDoesNotExist() {