Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph merge stats size calculation #1844

Merged
merged 9 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Make calculations easier to read
Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan committed Jul 18, 2024
commit c1e3109f7bb56700c24b6a83f4d602ef97656f57
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
public class KNNCodecUtil {
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// References to objects are 8 bytes in size
public static final int JAVA_REFERENCE_SIZE = 8;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
// Java rounds each array size up to multiples of 8 bytes
Expand Down Expand Up @@ -75,24 +75,26 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect
* @return rough estimate of number of bytes used to store an array with the given parameters
*/
public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) {
// For more information about array storage in Java, visit https://www.javamex.com/tutorials/memory/array_memory_usage.shtml
// Note: a Java reference is 8 bytes on 64-bit machines and 4 bytes on 32-bit machines; this method assumes 64-bit
if (serializationMode == SerializationMode.ARRAY) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
vectorSize = roundVectorSize(vectorSize);
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
vectorsSize = roundVectorSize(vectorsSize);
return vectorsSize;
int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors;
int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE + JAVA_ARRAY_HEADER_SIZE;
sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray);
return sizeOfReferenceArray + sizeOfVectorArray;
} else if (serializationMode == SerializationMode.COLLECTION_OF_FLOATS) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
vectorSize = roundVectorSize(vectorSize);
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
vectorsSize = roundVectorSize(vectorsSize);
return vectorsSize;
int sizeOfVector = vectorLength * FLOAT_BYTE_SIZE;
int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors;
int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE;
sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray);
return sizeOfReferenceArray + sizeOfVectorArray;
} else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) {
int vectorSize = vectorLength;
vectorSize = roundVectorSize(vectorSize);
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
vectorsSize = roundVectorSize(vectorsSize);
return vectorsSize;
int sizeOfVector = vectorLength;
int sizeOfVectorArray = roundVectorSize(sizeOfVector) * numVectors;
int sizeOfReferenceArray = numVectors * JAVA_REFERENCE_SIZE;
sizeOfReferenceArray = roundVectorSize(sizeOfReferenceArray);
return sizeOfReferenceArray + sizeOfVectorArray;
} else {
throw new IllegalStateException("Unreachable code");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ public void testCalculateArraySize() {

// Array
SerializationMode serializationMode = SerializationMode.ARRAY;
assertEquals(256, calculateArraySize(numVectors, vectorLength, serializationMode));
assertEquals(272, calculateArraySize(numVectors, vectorLength, serializationMode));

// Collection of floats
serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
assertEquals(176, calculateArraySize(numVectors, vectorLength, serializationMode));
assertEquals(192, calculateArraySize(numVectors, vectorLength, serializationMode));

// Collection of bytes
serializationMode = SerializationMode.COLLECTIONS_OF_BYTES;
assertEquals(80, calculateArraySize(numVectors, vectorLength, serializationMode));
assertEquals(96, calculateArraySize(numVectors, vectorLength, serializationMode));
}
}
Loading