Skip to content

Add BSON Binary Subtype 9 support for vector storage and retrieval. #1528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rename Type to DataType.
  • Loading branch information
vbabanin committed Oct 17, 2024
commit 8de3743b0ce76490aed2c7bc2ef4bd735eb887a4
2 changes: 1 addition & 1 deletion bson/src/main/org/bson/Float32Vector.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public final class Float32Vector extends Vector {
private final float[] vectorData;

Float32Vector(final float[] vectorData) {
super(Dtype.FLOAT32);
super(DataType.FLOAT32);
this.vectorData = assertNotNull(vectorData);
}

Expand Down
2 changes: 1 addition & 1 deletion bson/src/main/org/bson/Int8Vector.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public final class Int8Vector extends Vector {
private byte[] vectorData;

Int8Vector(final byte[] vectorData) {
super(Dtype.INT8);
super(DataType.INT8);
this.vectorData = assertNotNull(vectorData);
}

Expand Down
4 changes: 2 additions & 2 deletions bson/src/main/org/bson/PackedBitVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public final class PackedBitVector extends Vector {
private final byte[] vectorData;

PackedBitVector(final byte[] vectorData, final byte padding) {
super(Dtype.PACKED_BIT);
super(DataType.PACKED_BIT);
this.vectorData = assertNotNull(vectorData);
this.padding = padding;
}
Expand All @@ -64,7 +64,7 @@ public byte[] getVectorArray() {
* Returns the padding value for this vector.
*
* <p>Padding refers to the number of least-significant bits in the final byte that are ignored when retrieving the vector data, as not
* all {@link Dtype}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.</p>
* all {@link DataType}'s have a bit length equal to a multiple of 8, and hence do not fit squarely into a certain number of bytes.</p>
* <p>
* NOTE: The underlying byte array is not copied; changes to the returned array will be reflected in this instance.
*
Expand Down
54 changes: 27 additions & 27 deletions bson/src/main/org/bson/Vector.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

/**
* Represents a vector that is stored and retrieved using the BSON Binary Subtype 9 format.
* This class supports multiple vector {@link Dtype}'s and provides static methods to create
* This class supports multiple vector {@link DataType}'s and provides static methods to create
* vectors.
* <p>
* Vectors are densely packed arrays of numbers, all the same type, which are stored efficiently
Expand All @@ -34,16 +34,16 @@
* @since BINARY_VECTOR
*/
public abstract class Vector {
private final Dtype vectorType;
private final DataType vectorType;

Vector(final Dtype vectorType) {
Vector(final DataType vectorType) {
this.vectorType = vectorType;
}

/**
* Creates a vector with the {@link Dtype#PACKED_BIT} data type.
* Creates a vector with the {@link DataType#PACKED_BIT} data type.
* <p>
* A {@link Dtype#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte
* A {@link DataType#PACKED_BIT} vector is a binary quantized vector where each element of a vector is represented by a single bit (0 or 1). Each byte
* can hold up to 8 bits (vector elements). The padding parameter is used to specify how many bits in the final byte should be ignored.</p>
*
* <p>For example, a vector with two bytes and a padding of 4 would have the following structure:</p>
Expand All @@ -59,7 +59,7 @@ public abstract class Vector {
*
* @param vectorData The byte array representing the packed bit vector data. Each byte can store 8 bits.
* @param padding The number of bits (0 to 7) to ignore in the final byte of the vector data.
* @return A {@link PackedBitVector} instance with the {@link Dtype#PACKED_BIT} data type.
* @return A {@link PackedBitVector} instance with the {@link DataType#PACKED_BIT} data type.
* @throws IllegalArgumentException If the padding value is greater than 7.
*/
public static PackedBitVector packedBitVector(final byte[] vectorData, final byte padding) {
Expand All @@ -70,32 +70,32 @@ public static PackedBitVector packedBitVector(final byte[] vectorData, final byt
}

/**
* Creates a vector with the {@link Dtype#INT8} data type.
* Creates a vector with the {@link DataType#INT8} data type.
*
* <p>A {@link Dtype#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector,
* <p>A {@link DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector represents an element of a vector,
* with values in the range [-128, 127].</p>
* <p>
* NOTE: The byte array `vectorData` is not copied; changes to the provided array will be reflected
* in the created {@link Int8Vector} instance.
*
* @param vectorData The byte array representing the {@link Dtype#INT8} vector data.
* @return A {@link Int8Vector} instance with the {@link Dtype#INT8} data type.
* @param vectorData The byte array representing the {@link DataType#INT8} vector data.
* @return A {@link Int8Vector} instance with the {@link DataType#INT8} data type.
*/
public static Int8Vector int8Vector(final byte[] vectorData) {
notNull("vectorData", vectorData);
return new Int8Vector(vectorData);
}

/**
* Creates a vector with the {@link Dtype#FLOAT32} data type.
* Creates a vector with the {@link DataType#FLOAT32} data type.
* <p>
* A {@link Dtype#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.</p>
* A {@link DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the vector is a float.</p>
* <p>
* NOTE: The float array `vectorData` is not copied; changes to the provided array will be reflected
* in the created {@link Float32Vector} instance.
*
* @param vectorData The float array representing the {@link Dtype#FLOAT32} vector data.
* @return A {@link Float32Vector} instance with the {@link Dtype#FLOAT32} data type.
* @param vectorData The float array representing the {@link DataType#FLOAT32} vector data.
* @return A {@link Float32Vector} instance with the {@link DataType#FLOAT32} data type.
*/
public static Float32Vector floatVector(final float[] vectorData) {
notNull("vectorData", vectorData);
Expand All @@ -106,49 +106,49 @@ public static Float32Vector floatVector(final float[] vectorData) {
* Returns the {@link PackedBitVector}.
*
* @return {@link PackedBitVector}.
* @throws IllegalStateException if this vector is not of type {@link Dtype#PACKED_BIT}. Use {@link #getDataType()} to check the vector
* @throws IllegalStateException if this vector is not of type {@link DataType#PACKED_BIT}. Use {@link #getDataType()} to check the vector
* type before calling this method.
*/
public PackedBitVector asPackedBitVector() {
ensureType(Dtype.PACKED_BIT);
ensureType(DataType.PACKED_BIT);
return (PackedBitVector) this;
}

/**
* Returns the {@link Int8Vector}.
*
* @return {@link Int8Vector}.
* @throws IllegalStateException if this vector is not of type {@link Dtype#INT8}. Use {@link #getDataType()} to check the vector
* @throws IllegalStateException if this vector is not of type {@link DataType#INT8}. Use {@link #getDataType()} to check the vector
* type before calling this method.
*/
public Int8Vector asInt8Vector() {
ensureType(Dtype.INT8);
ensureType(DataType.INT8);
return (Int8Vector) this;
}

/**
* Returns the {@link Float32Vector}.
*
* @return {@link Float32Vector}.
* @throws IllegalStateException if this vector is not of type {@link Dtype#FLOAT32}. Use {@link #getDataType()} to check the vector
* @throws IllegalStateException if this vector is not of type {@link DataType#FLOAT32}. Use {@link #getDataType()} to check the vector
* type before calling this method.
*/
public Float32Vector asFloat32Vector() {
ensureType(Dtype.FLOAT32);
ensureType(DataType.FLOAT32);
return (Float32Vector) this;
}

/**
* Returns {@link Dtype} of the vector.
* Returns {@link DataType} of the vector.
*
* @return the data type of the vector.
*/
public Dtype getDataType() {
public DataType getDataType() {
return this.vectorType;
}


private void ensureType(final Dtype expected) {
private void ensureType(final DataType expected) {
if (this.vectorType != expected) {
throw new IllegalStateException("Expected vector type " + expected + " but found " + this.vectorType);
}
Expand All @@ -160,7 +160,7 @@ private void ensureType(final Dtype expected) {
* Each dtype determines how the data in the vector is stored, including how many bits are used to represent each element
* in the vector.
*/
public enum Dtype {
public enum DataType {
/**
* An INT8 vector is a vector of 8-bit signed integers. The vector is stored as an array of bytes, where each byte
* represents a signed integer in the range [-128, 127].
Expand All @@ -178,16 +178,16 @@ public enum Dtype {

private final byte value;

Dtype(final byte value) {
DataType(final byte value) {
this.value = value;
}

/**
* Returns the byte value associated with this {@link Dtype}.
* Returns the byte value associated with this {@link DataType}.
*
* <p>This value is used in the BSON binary format to indicate the data type of the vector.</p>
*
* @return the byte value representing the {@link Dtype}.
* @return the byte value representing the {@link DataType}.
*/
public byte getValue() {
return value;
Expand Down
24 changes: 12 additions & 12 deletions bson/src/main/org/bson/internal/vector/VectorHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ private VectorHelper() {
private static final int FLOAT_SIZE = 4;

public static byte[] encodeVectorToBinary(final Vector vector) {
Vector.Dtype dtype = vector.getDataType();
switch (dtype) {
Vector.DataType dataType = vector.getDataType();
switch (dataType) {
case INT8:
return writeVector(dtype.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray());
return writeVector(dataType.getValue(), (byte) 0, vector.asInt8Vector().getVectorArray());
case PACKED_BIT:
PackedBitVector packedBitVector = vector.asPackedBitVector();
return writeVector(dtype.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray());
return writeVector(dataType.getValue(), packedBitVector.getPadding(), packedBitVector.getVectorArray());
case FLOAT32:
return writeVector(dtype.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray());
return writeVector(dataType.getValue(), (byte) 0, vector.asFloat32Vector().getVectorArray());

default:
throw new AssertionError("Unknown vector dtype: " + dtype);
throw new AssertionError("Unknown vector dtype: " + dataType);
}
}

Expand All @@ -72,9 +72,9 @@ public static byte[] encodeVectorToBinary(final Vector vector) {
public static Vector decodeBinaryToVector(final byte[] encodedVector) {
isTrue("Vector encoded array length must be at least 2.", encodedVector.length >= METADATA_SIZE);

Vector.Dtype dtype = determineVectorDType(encodedVector[0]);
Vector.DataType dataType = determineVectorDType(encodedVector[0]);
byte padding = encodedVector[1];
switch (dtype) {
switch (dataType) {
case INT8:
isTrue("Padding must be 0 for INT8 data type.", padding == 0);
byte[] int8Vector = getVectorBytesWithoutMetadata(encodedVector);
Expand All @@ -91,7 +91,7 @@ public static Vector decodeBinaryToVector(final byte[] encodedVector) {
return Vector.floatVector(readLittleEndianFloats(encodedVector));

default:
throw new AssertionError("Unknown vector data type: " + dtype);
throw new AssertionError("Unknown vector data type: " + dataType);
}
}

Expand Down Expand Up @@ -147,9 +147,9 @@ private static float[] readLittleEndianFloats(final byte[] encodedVector) {
return floatArray;
}

public static Vector.Dtype determineVectorDType(final byte dType) {
Vector.Dtype[] values = Vector.Dtype.values();
for (Vector.Dtype value : values) {
public static Vector.DataType determineVectorDType(final byte dType) {
Vector.DataType[] values = Vector.DataType.values();
for (Vector.DataType value : values) {
if (value.getValue() == dType) {
return value;
}
Expand Down
2 changes: 1 addition & 1 deletion bson/src/test/unit/org/bson/BsonBinaryWriterTest.java
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the refactoring done to the expectedValues in this test 👍

Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

public class BsonBinaryWriterTest {

private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue();
private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue();
private static final int ZERO_PADDING = 0;

private BsonBinaryWriter writer;
Expand Down
8 changes: 4 additions & 4 deletions bson/src/test/unit/org/bson/VectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void shouldCreateInt8Vector() {

// then
assertNotNull(vector);
assertEquals(Vector.Dtype.INT8, vector.getDataType());
assertEquals(Vector.DataType.INT8, vector.getDataType());
assertArrayEquals(data, vector.getVectorArray());
}

Expand All @@ -61,7 +61,7 @@ void shouldCreateFloat32Vector() {

// then
assertNotNull(vector);
assertEquals(Vector.Dtype.FLOAT32, vector.getDataType());
assertEquals(Vector.DataType.FLOAT32, vector.getDataType());
assertArrayEquals(data, vector.getVectorArray());
}

Expand All @@ -87,7 +87,7 @@ void shouldCreatePackedBitVector(final byte validPadding) {

// then
assertNotNull(vector);
assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType());
assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType());
assertArrayEquals(data, vector.getVectorArray());
assertEquals(validPadding, vector.getPadding());
}
Expand Down Expand Up @@ -127,7 +127,7 @@ void shouldCreatePackedBitVectorWithZeroPaddingAndEmptyData() {

// then
assertNotNull(vector);
assertEquals(Vector.Dtype.PACKED_BIT, vector.getDataType());
assertEquals(Vector.DataType.PACKED_BIT, vector.getDataType());
assertArrayEquals(data, vector.getVectorArray());
assertEquals(padding, vector.getPadding());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

class VectorHelperTest {
private static final byte FLOAT32_DTYPE = Vector.Dtype.FLOAT32.getValue();
private static final byte INT8_DTYPE = Vector.Dtype.INT8.getValue();
private static final byte PACKED_BIT_DTYPE = Vector.Dtype.PACKED_BIT.getValue();
private static final byte FLOAT32_DTYPE = Vector.DataType.FLOAT32.getValue();
private static final byte INT8_DTYPE = Vector.DataType.INT8.getValue();
private static final byte PACKED_BIT_DTYPE = Vector.DataType.PACKED_BIT.getValue();
public static final int ZERO_PADDING = 0;

@ParameterizedTest(name = "{index}: {0}")
Expand All @@ -54,7 +54,7 @@ void shouldDecodeFloatVector(final Float32Vector expectedFloatVector, final byte
Float32Vector decodedVector = (Float32Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector);

// then
assertEquals(Vector.Dtype.FLOAT32, decodedVector.getDataType());
assertEquals(Vector.DataType.FLOAT32, decodedVector.getDataType());
assertArrayEquals(expectedFloatVector.getVectorArray(), decodedVector.getVectorArray());
}

Expand Down Expand Up @@ -102,7 +102,7 @@ void shouldDecodeInt8Vector(final Int8Vector expectedInt8Vector, final byte[] bs
Int8Vector decodedVector = (Int8Vector) VectorHelper.decodeBinaryToVector(bsonEncodedVector);

// then
assertEquals(Vector.Dtype.INT8, decodedVector.getDataType());
assertEquals(Vector.DataType.INT8, decodedVector.getDataType());
assertArrayEquals(expectedInt8Vector.getVectorArray(), decodedVector.getVectorArray());
}

Expand Down Expand Up @@ -135,7 +135,7 @@ void shouldDecodePackedBitVector(final PackedBitVector expectedPackedBitVector,
PackedBitVector decodedVector = (PackedBitVector) VectorHelper.decodeBinaryToVector(bsonEncodedVector);

// then
assertEquals(Vector.Dtype.PACKED_BIT, decodedVector.getDataType());
assertEquals(Vector.DataType.PACKED_BIT, decodedVector.getDataType());
assertArrayEquals(expectedPackedBitVector.getVectorArray(), decodedVector.getVectorArray());
assertEquals(expectedPackedBitVector.getPadding(), decodedVector.getPadding());
}
Expand Down Expand Up @@ -220,12 +220,12 @@ void shouldThrowExceptionForInvalidPackedBitArrayPaddingWhenDecodeEmptyVector(fi
@Test
void shouldDetermineVectorDType() {
// given
Vector.Dtype[] values = Vector.Dtype.values();
Vector.DataType[] values = Vector.DataType.values();

for (Vector.Dtype value : values) {
for (Vector.DataType value : values) {
// when
byte dtype = value.getValue();
Vector.Dtype actual = VectorHelper.determineVectorDType(dtype);
Vector.DataType actual = VectorHelper.determineVectorDType(dtype);

// then
assertEquals(value, actual);
Expand Down
Loading