Skip to content

Implement off-heap quantized scoring #14863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private byte[] bytesA;
private byte[] bytesB;
private byte[] halfBytesA;
private byte[] halfBytesAPacked;
private byte[] halfBytesB;
private byte[] halfBytesBPacked;
private float[] floatsA;
Expand Down Expand Up @@ -84,6 +85,9 @@ public void init() {
}
// pack the half byte arrays
if (size % 2 == 0) {
halfBytesAPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesA, halfBytesAPacked);

halfBytesBPacked = new byte[(size + 1) >> 1];
compressBytes(halfBytesB, halfBytesBPacked);
}
Expand Down Expand Up @@ -146,7 +150,7 @@ public int binaryHalfByteScalarPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
}
Expand All @@ -159,13 +163,30 @@ public int binaryHalfByteVectorPacked() {
if (size % 2 != 0) {
throw new RuntimeException("Size must be even for this benchmark");
}
int v = VectorUtil.int4DotProductPacked(halfBytesA, halfBytesBPacked);
int v = VectorUtil.int4DotProductSinglePacked(halfBytesA, halfBytesBPacked);
if (v != expectedhalfByteDotProduct) {
throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
}
return v;
}

/**
 * Benchmarks the int4 dot product where both operands are nibble-packed (two int4 values per
 * byte), using the default (scalar) JVM configuration.
 *
 * <p>Like the single-packed benchmarks, the result is verified against the expected half-byte dot
 * product so that an incorrect kernel fails loudly instead of being measured silently.
 *
 * @return the computed dot product
 * @throws RuntimeException if {@code size} is odd, or the kernel result is not the expected value
 */
@Benchmark
public int binaryHalfByteScalarPackedPacked() {
  if (size % 2 != 0) {
    throw new RuntimeException("Size must be even for this benchmark");
  }
  int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
  // Packing only regroups the nibbles, so the result must match the unpacked dot product.
  if (v != expectedhalfByteDotProduct) {
    throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
  }
  return v;
}

/**
 * Benchmarks the int4 dot product where both operands are nibble-packed (two int4 values per
 * byte), in a forked JVM started with the {@code jdk.incubator.vector} module added.
 *
 * <p>Like the single-packed benchmarks, the result is verified against the expected half-byte dot
 * product so that an incorrect kernel fails loudly instead of being measured silently.
 *
 * @return the computed dot product
 * @throws RuntimeException if {@code size} is odd, or the kernel result is not the expected value
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int binaryHalfByteVectorPackedPacked() {
  if (size % 2 != 0) {
    throw new RuntimeException("Size must be even for this benchmark");
  }
  int v = VectorUtil.int4DotProductBothPacked(halfBytesAPacked, halfBytesBPacked);
  // Packing only regroups the nibbles, so the result must match the unpacked dot product.
  if (v != expectedhalfByteDotProduct) {
    throw new RuntimeException("Expected " + expectedhalfByteDotProduct + " but got " + v);
  }
  return v;
}

@Benchmark
public float floatCosineScalar() {
return VectorUtil.cosine(floatsA, floatsB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ private FlatVectorScorerUtil() {}
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
return IMPL.getLucene99FlatVectorsScorer();
}

/**
 * Returns a FlatVectorsScorer for Lucene99 scalar quantized vectors, obtained by delegating to
 * {@code IMPL}.
 */
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FloatToFloatFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
Expand Down Expand Up @@ -242,7 +243,7 @@ public float score(int vectorOrdinal) throws IOException {
values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
int dotProduct = VectorUtil.int4DotProductSinglePacked(targetBytes, compressedVector);
// For the current implementation of scalar quantization, all dotproducts should be >= 0;
assert dotProduct >= 0;
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
Expand Down Expand Up @@ -296,11 +297,6 @@ public void setScoringOrdinal(int node) throws IOException {
}
}

@FunctionalInterface
private interface FloatToFloatFunction {
float apply(float f);
}

private static final class ScalarQuantizedRandomVectorScorerSupplier
implements RandomVectorScorerSupplier {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
Expand Down Expand Up @@ -70,7 +70,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final FlatVectorsScorer flatVectorScorer;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
Expand Down Expand Up @@ -117,8 +117,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.flatVectorScorer = FlatVectorScorerUtil.getLucene99ScalarQuantizedVectorsScorer();
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,24 +155,35 @@ public int dotProduct(byte[] a, byte[] b) {
}

@Override
public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
assert (apacked && bpacked) == false;
if (apacked || bpacked) {
byte[] packed = apacked ? a : b;
byte[] unpacked = apacked ? b : a;
int total = 0;
for (int i = 0; i < packed.length; i++) {
byte packedByte = packed[i];
byte unpacked1 = unpacked[i];
byte unpacked2 = unpacked[i + packed.length];
total += (packedByte & 0x0F) * unpacked2;
total += ((packedByte & 0xFF) >> 4) * unpacked1;
}
return total;
}
public int int4DotProduct(byte[] a, byte[] b) {
return dotProduct(a, b);
}

/**
 * Scalar int4 dot product between an unpacked vector and a nibble-packed one.
 *
 * <p>Each byte of {@code packed} carries two int4 values: the high nibble is multiplied with
 * {@code unpacked[i]} and the low nibble with {@code unpacked[i + packed.length]}.
 *
 * @param unpacked one int4 value per byte
 * @param packed two int4 values per byte
 * @return the accumulated dot product
 */
@Override
public int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
  final int half = packed.length;
  int sum = 0;
  for (int idx = 0; idx < half; idx++) {
    // Mask to an unsigned value first so the shift below cannot drag in sign bits.
    final int bits = packed[idx] & 0xFF;
    final int firstHalfValue = unpacked[idx];
    final int secondHalfValue = unpacked[idx + half];
    sum += (bits & 0x0F) * secondHalfValue;
    sum += (bits >>> 4) * firstHalfValue;
  }
  return sum;
}

/**
 * Scalar int4 dot product between two nibble-packed vectors.
 *
 * <p>Low nibbles are multiplied with low nibbles and high nibbles with high nibbles; the loop
 * runs over {@code a.length} bytes.
 *
 * @param a first packed vector, two int4 values per byte
 * @param b second packed vector, two int4 values per byte
 * @return the accumulated dot product
 */
@Override
public int int4DotProductBothPacked(byte[] a, byte[] b) {
  int sum = 0;
  final int length = a.length;
  for (int idx = 0; idx < length; idx++) {
    // Mask to unsigned values first so the shifts below cannot drag in sign bits.
    final int left = a[idx] & 0xFF;
    final int right = b[idx] & 0xFF;
    sum += (left & 0x0F) * (right & 0x0F);
    sum += (left >>> 4) * (right >>> 4);
  }
  return sum;
}

@Override
public float cosine(byte[] a, byte[] b) {
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.store.IndexInput;

/** Default provider returning scalar implementations. */
Expand All @@ -40,6 +41,11 @@ public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}

@Override
public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
// Wraps the default flat scorer in the scalar quantized scorer; a new wrapper is created per call.
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}

@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
return new PostingDecodingUtil(input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ public interface VectorUtilSupport {
/** Returns the dot product computed over signed bytes. */
int dotProduct(byte[] a, byte[] b);

/** Returns the dot product computed over the bytes, assuming the values are int4 encoded. */
int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked);
int int4DotProduct(byte[] a, byte[] b);

/**
 * Returns the int4 dot product of an unpacked vector (one value per byte) and a packed vector
 * (two int4 values per byte).
 */
int int4DotProductSinglePacked(byte[] unpacked, byte[] packed);

/** Returns the int4 dot product of two packed vectors, each holding two int4 values per byte. */
int int4DotProductBothPacked(byte[] a, byte[] b);

/** Returns the cosine similarity between the two byte vectors. */
float cosine(byte[] a, byte[] b);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ public static VectorizationProvider getInstance() {
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();

/** Returns a FlatVectorsScorer that supports the Lucene99 scalar quantized format. */
public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer();

/** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */
public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.util;

/**
* Simple interface to map one float to another (useful in scaling scores).
*
* @lucene.internal
*/
@FunctionalInterface
public interface FloatToFloatFunction {
/**
* Maps the given float to another float.
*
* @param f the input value
* @return the mapped value
*/
float apply(float f);
}
14 changes: 11 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public static int int4DotProduct(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
return IMPL.int4DotProduct(a, false, b, false);
return IMPL.int4DotProduct(a, b);
}

/**
Expand All @@ -189,12 +189,20 @@ public static int int4DotProduct(byte[] a, byte[] b) {
* @param packed the packed vector, of length {@code (unpacked.length + 1) / 2}
* @return the value of the dot product of the two vectors
*/
public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
public static int int4DotProductSinglePacked(byte[] unpacked, byte[] packed) {
if (packed.length != ((unpacked.length + 1) >> 1)) {
throw new IllegalArgumentException(
"vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length);
}
return IMPL.int4DotProduct(unpacked, false, packed, true);
return IMPL.int4DotProductSinglePacked(unpacked, packed);
}

/**
 * Dot product of two int4 vectors that are both supplied in packed form (two int4 values per
 * byte). The computation itself is delegated to {@code IMPL}.
 *
 * @param a the first packed vector
 * @param b the second packed vector, of the same length as {@code a}
 * @return the value of the dot product of the two vectors
 * @throws IllegalArgumentException if the two arrays differ in length
 */
public static int int4DotProductBothPacked(byte[] a, byte[] b) {
  if (a.length == b.length) {
    return IMPL.int4DotProductBothPacked(a, b);
  }
  final String message = "vector dimensions differ: " + a.length + " != " + b.length;
  throw new IllegalArgumentException(message);
}

/**
Expand Down
Loading
Loading