Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,22 @@
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.INFER_DIRECTION;
import static com.nvidia.cuvs.internal.common.Util.allocateRMMSegment;
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
import static com.nvidia.cuvs.internal.common.Util.concatenate;
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndexCreate;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndexDestroy;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndex_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSearch;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSerialize;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMAlloc;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;

Expand Down Expand Up @@ -130,6 +128,9 @@ public void destroyIndex() {
bruteForceIndexReference.datasetBytes);
checkCuVSError(returnValue, "cuvsRMMFree");
}
if (bruteForceIndexReference.tensorDataArena != null) {
bruteForceIndexReference.tensorDataArena.close();
}
} finally {
destroyed = true;
}
Expand All @@ -143,49 +144,41 @@ public void destroyIndex() {
* index
*/
private IndexReference build(DatasetImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
try (var localArena = Arena.ofConfined()) {
long rows = dataset.size();
long cols = dataset.dimensions();

Arena arena = resources.getArena();
MemorySegment datasetMemSegment = dataset.asMemorySegment();

long cuvsResources = resources.getHandle();

omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
long rows = dataset.size();
long cols = dataset.dimensions();

MemorySegment datasetMemorySegment = localArena.allocate(C_POINTER);
MemorySegment datasetMemSegment = dataset.asMemorySegment();

long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
var returnValue = cuvsRMMAlloc(cuvsResources, datasetMemorySegment, datasetBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
long cuvsResources = resources.getHandle();

// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
MemorySegment datasetMemorySegmentP = datasetMemorySegment.get(C_POINTER, 0);
omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
MemorySegment datasetMemorySegmentP = allocateRMMSegment(cuvsResources, datasetBytes);

cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);

long[] datasetShape = {rows, cols};
MemorySegment datasetTensor =
prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);
long[] datasetShape = {rows, cols};
var tensorDataArena = Arena.ofShared();
MemorySegment datasetTensor =
prepareTensor(tensorDataArena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);

var indexReference =
new IndexReference(datasetMemorySegmentP, datasetBytes, createBruteForceIndex());
var indexReference =
new IndexReference(
datasetMemorySegmentP, datasetBytes, tensorDataArena, createBruteForceIndex());

returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");
var returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");

returnValue =
cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.indexPtr);
checkCuVSError(returnValue, "cuvsBruteForceBuild");
returnValue =
cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.indexPtr);
checkCuVSError(returnValue, "cuvsBruteForceBuild");

returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");
returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");

omp_set_num_threads(1);
omp_set_num_threads(1);

return indexReference;
}
return indexReference;
}

/**
Expand All @@ -203,12 +196,11 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
long numQueries = cuvsQuery.getQueryVectors().length;
long numBlocks = cuvsQuery.getTopK() * numQueries;
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
Arena arena = resources.getArena();

SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
MemorySegment neighborsMemorySegment = arena.allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = arena.allocate(distancesSequenceLayout);
MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);

// prepare the prefiltering data
long prefilterDataLength = 0;
Expand All @@ -217,77 +209,59 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
if (prefilters != null && prefilters.length > 0) {
BitSet concatenatedFilters = concatenate(prefilters, cuvsQuery.getNumDocs());
long[] filters = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = buildMemorySegment(arena, filters);
prefilterDataMemorySegment = buildMemorySegment(localArena, filters);
prefilterDataLength = (long) cuvsQuery.getNumDocs() * prefilters.length;
}

MemorySegment querySeg = buildMemorySegment(arena, cuvsQuery.getQueryVectors());
MemorySegment querySeg = buildMemorySegment(localArena, cuvsQuery.getQueryVectors());

int topk = cuvsQuery.getTopK();
long cuvsResources = resources.getHandle();
MemorySegment stream = arena.allocate(cudaStream_t);
var returnValue = cuvsStreamGet(cuvsResources, stream);
checkCuVSError(returnValue, "cuvsStreamGet");

MemorySegment queriesD = localArena.allocate(C_POINTER);
MemorySegment neighborsD = localArena.allocate(C_POINTER);
MemorySegment distancesD = localArena.allocate(C_POINTER);

long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
long prefilterBytes = 0; // size assigned later

returnValue = cuvsRMMAlloc(cuvsResources, queriesD, queriesBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
returnValue = cuvsRMMAlloc(cuvsResources, neighborsD, neighborsBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
returnValue = cuvsRMMAlloc(cuvsResources, distancesD, distanceBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");

// IMPORTANT: these three should only come AFTER cuvsRMMAlloc calls
MemorySegment queriesDP = queriesD.get(C_POINTER, 0);
MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0);
MemorySegment distancesDP = distancesD.get(C_POINTER, 0);
MemorySegment queriesDP = allocateRMMSegment(cuvsResources, queriesBytes);
MemorySegment neighborsDP = allocateRMMSegment(cuvsResources, neighborsBytes);
MemorySegment distancesDP = allocateRMMSegment(cuvsResources, distanceBytes);
MemorySegment prefilterDP = MemorySegment.NULL;

cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);

long[] queriesShape = {numQueries, vectorDimension};
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
MemorySegment queriesTensor =
prepareTensor(localArena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
long[] neighborsShape = {numQueries, topk};
MemorySegment neighborsTensor =
prepareTensor(arena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
prepareTensor(localArena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
long[] distancesShape = {numQueries, topk};
MemorySegment distancesTensor =
prepareTensor(arena, distancesDP, distancesShape, 2, 32, 2, 2, 1);
prepareTensor(localArena, distancesDP, distancesShape, 2, 32, 2, 2, 1);

MemorySegment prefilter = cuvsFilter.allocate(arena);
MemorySegment prefilter = cuvsFilter.allocate(localArena);
MemorySegment prefilterTensor;

if (prefilterDataMemorySegment == MemorySegment.NULL) {
cuvsFilter.type(prefilter, 0); // NO_FILTER
cuvsFilter.addr(prefilter, 0);
} else {
long[] prefilterShape = {(prefilterDataLength + 31) / 32};

MemorySegment prefilterD = localArena.allocate(C_POINTER);
long prefilterLen = prefilterShape[0];
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;

returnValue = cuvsRMMAlloc(cuvsResources, prefilterD, prefilterBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
prefilterDP = prefilterD.get(C_POINTER, 0);
prefilterDP = allocateRMMSegment(cuvsResources, prefilterBytes);

cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);

prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
prefilterTensor = prepareTensor(localArena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);

cuvsFilter.type(prefilter, 2);
cuvsFilter.addr(prefilter, prefilterTensor.address());
}

returnValue = cuvsStreamSync(cuvsResources);
var returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");

returnValue =
Expand Down Expand Up @@ -361,13 +335,12 @@ private static MemorySegment createBruteForceIndex() {
try (var localArena = Arena.ofConfined()) {
MemorySegment indexPtrPtr = localArena.allocate(cuvsBruteForceIndex_t);
// cuvsBruteForceIndexCreate gets a pointer to a cuvsBruteForceIndex_t, which is defined as a
// pointer to
// cuvsBruteForceIndex.
// It's basically a "out" parameter: the C functions will create the index and "return back" a
// pointer to it.
// pointer to cuvsBruteForceIndex.
// It's basically an "out" parameter: the C functions will create the index and "return back"
// a pointer to it.
// The "out parameter" pointer is needed only for the duration of the function invocation (it
// could be a stack
// pointer, in C) so we allocate it from our localArena, unwrap it and return it.
// could be a stack pointer, in C) so we allocate it from our localArena, unwrap it and return
// it.
var returnValue = cuvsBruteForceIndexCreate(indexPtrPtr);
checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");
return indexPtrPtr.get(cuvsBruteForceIndex_t, 0);
Expand Down Expand Up @@ -498,23 +471,31 @@ public BruteForceIndexImpl build() throws Throwable {
}

/**
* Holds the memory reference to a BRUTEFORCE index and its associated dataset
* Holds the memory reference to a BRUTEFORCE index, its associated dataset, and the arena used to allocate
* input data
*/
private static class IndexReference {

private final MemorySegment datasetPtr;
private final long datasetBytes;
private final Arena tensorDataArena;
private final MemorySegment indexPtr;

private IndexReference(MemorySegment datasetPtr, long datasetBytes, MemorySegment indexPtr) {
private IndexReference(
MemorySegment datasetPtr,
long datasetBytes,
Arena tensorDataArena,
MemorySegment indexPtr) {
this.datasetPtr = datasetPtr;
this.datasetBytes = datasetBytes;
this.tensorDataArena = tensorDataArena;
this.indexPtr = indexPtr;
}

private IndexReference(MemorySegment indexPtr) {
this.datasetPtr = MemorySegment.NULL;
this.datasetBytes = 0;
this.tensorDataArena = null;
this.indexPtr = indexPtr;
}
}
Expand Down
Loading