Skip to content

Implement off-heap quantized scoring #14863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,12 @@ private FlatVectorScorerUtil() {}
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
  // Pure delegation: whatever IMPL supplies is handed back unchanged.
  final FlatVectorsScorer scorer = IMPL.getLucene99FlatVectorsScorer();
  return scorer;
}

/**
 * Provides the {@link FlatVectorsScorer} to use with the quantized Lucene99 format.
 *
 * <p>The scorer is obtained from {@code IMPL}, so the returned instance may be optimized on
 * certain platforms.
 *
 * @return the scorer supplied by {@code IMPL} for quantized Lucene99 vectors
 */
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
  final FlatVectorsScorer quantizedScorer = IMPL.getLucene99ScalarQuantizedVectorsScorer();
  return quantizedScorer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
Expand Down Expand Up @@ -70,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final FlatVectorsScorer flatVectorScorer;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
Expand Down Expand Up @@ -117,8 +117,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.store.IndexInput;

/** Default provider returning scalar implementations. */
Expand All @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}

/**
 * Returns the default scorer for the quantized Lucene99 format: a {@link
 * Lucene99ScalarQuantizedVectorScorer} wrapping the default flat vector scorer.
 */
@Override
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
  // Reuse the shared DefaultFlatVectorScorer singleton (the same instance that
  // getLucene99FlatVectorsScorer() returns) instead of allocating a fresh delegate per call.
  return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
return new PostingDecodingUtil(input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ public static VectorizationProvider getInstance() {
/**
 * Returns a {@link FlatVectorsScorer} that supports the Lucene99 format.
 *
 * @return the flat vectors scorer supplied by this provider
 */
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();

/**
 * Returns a {@link FlatVectorsScorer} that supports the quantized Lucene99 format.
 *
 * @return the scorer this provider supplies for scalar-quantized Lucene99 vectors
 */
public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer();

/**
 * Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}.
 *
 * @param input the index input the returned utility is created for
 * @return a new {@link PostingDecodingUtil} for {@code input}
 * @throws IOException declared so that implementations may report I/O failures
 */
public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.internal.vectorization;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory;
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction;
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

class Lucene99MemorySegmentScalarQuantizedScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {

private final VectorSimilarityFunction function;
private final QuantizedByteVectorValues values;
private final MemorySegmentAccessInput input;
private final MemorySegmentScorer scorer;
private final FloatToFloatFunction scaler;
private final float constMultiplier;
private final int vectorByteSize;
private final int entrySize;
private final MemorySegment query;
private final float queryOffset;
private final byte[][] docScratch;

public Lucene99MemorySegmentScalarQuantizedScorer(
VectorSimilarityFunction function,
QuantizedByteVectorValues values,
MemorySegmentAccessInput input,
float[] target) {

super(values);
this.function = function;
this.values = values;
this.input = input;
this.scorer = factory(function, values, false);
this.scaler = factory(function);

ScalarQuantizer quantizer = values.getScalarQuantizer();
this.constMultiplier = quantizer.getConstantMultiplier();
this.vectorByteSize = values.getVectorByteLength();
this.entrySize = vectorByteSize + Float.BYTES;

byte[] targetBytes = new byte[target.length];
this.queryOffset = quantizeQuery(target, targetBytes, function, quantizer);
this.query = Arena.ofAuto().allocateFrom(JAVA_BYTE, targetBytes);

this.docScratch = new byte[1][];
}

@Override
public float score(int node) throws IOException {
MemorySegment segment = getSegment(input, entrySize, node, docScratch);
MemorySegment doc = segment.reinterpret(vectorByteSize);
float docOffset = segment.get(JAVA_FLOAT, vectorByteSize);
return scaler.scale(scorer.score(query, doc) * constMultiplier + queryOffset + docOffset);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.internal.vectorization;

import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.factory;
import static org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.getSegment;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.FloatToFloatFunction;
import org.apache.lucene.internal.vectorization.Lucene99MemorySegmentScalarQuantizedVectorScorer.MemorySegmentScorer;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

class Lucene99MemorySegmentScalarQuantizedScorerSupplier implements RandomVectorScorerSupplier {

private final VectorSimilarityFunction function;
private final QuantizedByteVectorValues values;
private final MemorySegmentAccessInput input;
private final MemorySegmentScorer scorer;
private final FloatToFloatFunction scaler;
private final float constMultiplier;
private final int vectorByteSize;
private final int entrySize;

public Lucene99MemorySegmentScalarQuantizedScorerSupplier(
VectorSimilarityFunction function,
QuantizedByteVectorValues values,
MemorySegmentAccessInput input) {

this.function = function;
this.values = values;
this.input = input;
this.scorer = factory(function, values, true);
this.scaler = factory(function);
this.constMultiplier = values.getScalarQuantizer().getConstantMultiplier();
this.vectorByteSize = values.getVectorByteLength();
this.entrySize = vectorByteSize + Float.BYTES;
}

@Override
public UpdateableRandomVectorScorer scorer() {
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {

private final MemorySegment[] doc = new MemorySegment[1];
private final float[] docOffset = new float[1];
private final byte[][] docScratch = new byte[1][];
private final byte[][] queryScratch = new byte[1][];

@Override
public void setScoringOrdinal(int node) throws IOException {
MemorySegment segment = getSegment(input, entrySize, node, docScratch);
doc[0] = segment.reinterpret(vectorByteSize);
docOffset[0] = segment.get(JAVA_FLOAT, vectorByteSize);
}

@Override
public float score(int node) throws IOException {
MemorySegment segment = getSegment(input, entrySize, node, queryScratch);
MemorySegment query = segment.reinterpret(vectorByteSize);
float queryOffset = segment.get(JAVA_FLOAT, vectorByteSize);
return scaler.scale(
scorer.score(query, doc[0]) * constMultiplier + queryOffset + docOffset[0]);
}
};
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new Lucene99MemorySegmentScalarQuantizedScorerSupplier(function, values, input);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

/**
 * {@link FlatVectorsScorer} that scores scalar-quantized vectors directly from a {@link
 * MemorySegmentAccessInput} when one is available, and otherwise hands the request to {@link
 * Lucene99MemorySegmentFlatVectorsScorer}.
 */
public class Lucene99MemorySegmentScalarQuantizedVectorScorer implements FlatVectorsScorer {

  /** Shared instance; the class holds no per-instance state. */
  public static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE =
      new Lucene99MemorySegmentScalarQuantizedVectorScorer();

  private static final FlatVectorsScorer NON_QUANTIZED_DELEGATE =
      Lucene99MemorySegmentFlatVectorsScorer.INSTANCE;

  /**
   * Returns a supplier of scorers between stored vectors.
   *
   * @param similarityFunction similarity used for scoring
   * @param vectorValues stored vectors
   * @return the memory-segment supplier when {@code vectorValues} is quantized and backed by a
   *     {@link MemorySegmentAccessInput}; otherwise whatever the delegate returns
   */
  @Override
  public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
      throws IOException {
    if (vectorValues instanceof QuantizedByteVectorValues values
        && values.getSlice() instanceof MemorySegmentAccessInput input) {
      return new Lucene99MemorySegmentScalarQuantizedScorerSupplier(
          similarityFunction, values, input);
    }
    // Reached for non-quantized values (possible during initial indexing and flush).
    // NOTE(review): also reached for QuantizedByteVectorValues whose slice is not a
    // MemorySegmentAccessInput - confirm the delegate applies the quantization corrections there.
    return NON_QUANTIZED_DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
  }

  /**
   * Returns a scorer of stored vectors against a float query.
   *
   * @param similarityFunction similarity used for scoring
   * @param vectorValues stored vectors
   * @param target float query vector
   * @return the memory-segment scorer when {@code vectorValues} is quantized and backed by a
   *     {@link MemorySegmentAccessInput}; otherwise whatever the delegate returns
   * @throws IllegalArgumentException on the memory-segment path, if {@code target.length} differs
   *     from the field dimension
   */
  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
      throws IOException {
    if (vectorValues instanceof QuantizedByteVectorValues values
        && values.getSlice() instanceof MemorySegmentAccessInput input) {
      checkDimensions(target.length, vectorValues.dimension());
      return new Lucene99MemorySegmentScalarQuantizedScorer(
          similarityFunction, values, input, target);
    }
    // Reached for non-quantized values (possible during initial indexing and flush).
    // NOTE(review): also reached for QuantizedByteVectorValues whose slice is not a
    // MemorySegmentAccessInput - confirm the delegate can score a float query against them.
    return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
  }

  /** Byte queries are always handed to the delegate unchanged. */
  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
      throws IOException {
    return NON_QUANTIZED_DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
  }

  @Override
  public String toString() {
    return getClass().getSimpleName() + "()";
  }

  /** Rejects a query whose length does not match the field dimension. */
  private static void checkDimensions(int queryLen, int fieldLen) {
    if (queryLen != fieldLen) {
      throw new IllegalArgumentException(
          "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
    }
  }

  /**
   * Returns a segment holding the {@code entrySize} bytes of entry {@code node}.
   *
   * <p>A slice is requested from {@code input} first. If none is returned, the bytes are copied
   * into {@code scratch[0]} (allocated on first use) and that array is wrapped as a heap segment,
   * so the result stays valid only until the same scratch holder is used again.
   *
   * @param input source of the entries
   * @param entrySize size of one entry in bytes
   * @param node ordinal of the entry
   * @param scratch single-slot holder for the reusable copy buffer
   */
  static MemorySegment getSegment(
      MemorySegmentAccessInput input, int entrySize, int node, byte[][] scratch)
      throws IOException {
    // Widen before multiplying so large ordinals cannot overflow int.
    long pos = (long) entrySize * node;
    MemorySegment segment = input.segmentSliceOrNull(pos, entrySize);
    if (segment == null) {
      if (scratch[0] == null) {
        scratch[0] = new byte[entrySize];
      }
      input.readBytes(pos, scratch[0], 0, entrySize);
      segment = MemorySegment.ofArray(scratch[0]);
    }
    return segment;
  }

  /** Raw (unscaled, uncorrected) score between two quantized vectors. */
  @FunctionalInterface
  interface MemorySegmentScorer {
    float score(MemorySegment query, MemorySegment doc);
  }

  /** Maps a corrected raw score to the final similarity score. */
  @FunctionalInterface
  interface FloatToFloatFunction {
    float scale(float score);
  }

  /**
   * Selects the raw scorer for {@code function} and the quantization settings of {@code values}.
   *
   * <p>For the dot-product family with at most 4 bits, {@code int4DotProduct} is used; its boolean
   * arguments depend on whether the stored byte length differs from the dimension and on whether
   * both operands are stored vectors (presumably they flag packed operands - confirm against
   * PanamaVectorUtilSupport).
   *
   * @param function similarity function
   * @param values quantized vectors supplying bit width, byte length and dimension
   * @param isScorerSupplier {@code true} when both operands are stored vectors, {@code false}
   *     when the first operand is a freshly quantized query
   * @throws UnsupportedOperationException for {@code EUCLIDEAN} with fewer than 7 bits
   */
  static MemorySegmentScorer factory(
      VectorSimilarityFunction function,
      QuantizedByteVectorValues values,
      boolean isScorerSupplier) {
    return switch (function) {
      case EUCLIDEAN -> {
        int bits = values.getScalarQuantizer().getBits();
        if (bits < 7) {
          // TODO: add a square-distance kernel for fewer than 7 bits.
          throw new UnsupportedOperationException(
              "EUCLIDEAN memory-segment scoring is not supported for "
                  + bits
                  + "-bit quantized vectors");
        }
        yield PanamaVectorUtilSupport::squareDistance;
      }
      case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> {
        if (values.getScalarQuantizer().getBits() <= 4) {
          if (values.getVectorByteLength() != values.dimension()) {
            if (isScorerSupplier) {
              yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, true, doc, true);
            }
            yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, true);
          }
          yield (query, doc) -> PanamaVectorUtilSupport.int4DotProduct(query, false, doc, false);
        }
        yield PanamaVectorUtilSupport::dotProduct;
      }
    };
  }

  /** Selects the function that turns a corrected raw score into the final score. */
  static FloatToFloatFunction factory(VectorSimilarityFunction function) {
    return switch (function) {
      case EUCLIDEAN -> score -> (1 / (1f + score));
      case DOT_PRODUCT, COSINE -> score -> Math.max((1f + score) / 2, 0);
      case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore;
    };
  }
}
Loading
Loading