Skip to content

Commit f61a27d

Browse files
authored
[Java] Tidy up MemorySegments lifecycle (#1069)
This PR tidies up the lifecycle of `MemorySegment`s by using specific confined `Arena`s where possible, and specific index-bound `Arena`s where the lifetime needs to be bound to the lifetime of an index. It addresses all remaining usages of `CuVSResources#arena` after #1024 and #1045; as such, we can remove `CuVSResources#arena` completely, and force native memory usages to be tracked and dealt with specifically. This would towards the goal of #1037 Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) - Ben Frederickson (https://github.com/benfred) Approvers: - Chris Hegarty (https://github.com/ChrisHegarty) - MithunR (https://github.com/mythrocks) URL: #1069
1 parent 0719080 commit f61a27d

File tree

5 files changed

+132
-166
lines changed

5 files changed

+132
-166
lines changed

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 60 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,22 @@
2020
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
2121
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
2222
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG_BYTE_SIZE;
23-
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER;
24-
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
23+
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE;
24+
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.INFER_DIRECTION;
25+
import static com.nvidia.cuvs.internal.common.Util.allocateRMMSegment;
2526
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
2627
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
2728
import static com.nvidia.cuvs.internal.common.Util.concatenate;
2829
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
2930
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
30-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
3131
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild;
3232
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize;
3333
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndexCreate;
3434
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndexDestroy;
3535
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndex_t;
3636
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSearch;
3737
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSerialize;
38-
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMAlloc;
3938
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
40-
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
4139
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
4240
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
4341

@@ -130,6 +128,9 @@ public void destroyIndex() {
130128
bruteForceIndexReference.datasetBytes);
131129
checkCuVSError(returnValue, "cuvsRMMFree");
132130
}
131+
if (bruteForceIndexReference.tensorDataArena != null) {
132+
bruteForceIndexReference.tensorDataArena.close();
133+
}
133134
} finally {
134135
destroyed = true;
135136
}
@@ -143,49 +144,41 @@ public void destroyIndex() {
143144
* index
144145
*/
145146
private IndexReference build(DatasetImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
146-
try (var localArena = Arena.ofConfined()) {
147-
long rows = dataset.size();
148-
long cols = dataset.dimensions();
149-
150-
Arena arena = resources.getArena();
151-
MemorySegment datasetMemSegment = dataset.asMemorySegment();
152-
153-
long cuvsResources = resources.getHandle();
154-
155-
omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
147+
long rows = dataset.size();
148+
long cols = dataset.dimensions();
156149

157-
MemorySegment datasetMemorySegment = localArena.allocate(C_POINTER);
150+
MemorySegment datasetMemSegment = dataset.asMemorySegment();
158151

159-
long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
160-
var returnValue = cuvsRMMAlloc(cuvsResources, datasetMemorySegment, datasetBytes);
161-
checkCuVSError(returnValue, "cuvsRMMAlloc");
152+
long cuvsResources = resources.getHandle();
162153

163-
// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
164-
MemorySegment datasetMemorySegmentP = datasetMemorySegment.get(C_POINTER, 0);
154+
omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
155+
long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
156+
MemorySegment datasetMemorySegmentP = allocateRMMSegment(cuvsResources, datasetBytes);
165157

166-
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
158+
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
167159

168-
long[] datasetShape = {rows, cols};
169-
MemorySegment datasetTensor =
170-
prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);
160+
long[] datasetShape = {rows, cols};
161+
var tensorDataArena = Arena.ofShared();
162+
MemorySegment datasetTensor =
163+
prepareTensor(tensorDataArena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);
171164

172-
var indexReference =
173-
new IndexReference(datasetMemorySegmentP, datasetBytes, createBruteForceIndex());
165+
var indexReference =
166+
new IndexReference(
167+
datasetMemorySegmentP, datasetBytes, tensorDataArena, createBruteForceIndex());
174168

175-
returnValue = cuvsStreamSync(cuvsResources);
176-
checkCuVSError(returnValue, "cuvsStreamSync");
169+
var returnValue = cuvsStreamSync(cuvsResources);
170+
checkCuVSError(returnValue, "cuvsStreamSync");
177171

178-
returnValue =
179-
cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.indexPtr);
180-
checkCuVSError(returnValue, "cuvsBruteForceBuild");
172+
returnValue =
173+
cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.indexPtr);
174+
checkCuVSError(returnValue, "cuvsBruteForceBuild");
181175

182-
returnValue = cuvsStreamSync(cuvsResources);
183-
checkCuVSError(returnValue, "cuvsStreamSync");
176+
returnValue = cuvsStreamSync(cuvsResources);
177+
checkCuVSError(returnValue, "cuvsStreamSync");
184178

185-
omp_set_num_threads(1);
179+
omp_set_num_threads(1);
186180

187-
return indexReference;
188-
}
181+
return indexReference;
189182
}
190183

191184
/**
@@ -203,12 +196,11 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
203196
long numQueries = cuvsQuery.getQueryVectors().length;
204197
long numBlocks = cuvsQuery.getTopK() * numQueries;
205198
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
206-
Arena arena = resources.getArena();
207199

208200
SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
209201
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
210-
MemorySegment neighborsMemorySegment = arena.allocate(neighborsSequenceLayout);
211-
MemorySegment distancesMemorySegment = arena.allocate(distancesSequenceLayout);
202+
MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
203+
MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);
212204

213205
// prepare the prefiltering data
214206
long prefilterDataLength = 0;
@@ -217,77 +209,59 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
217209
if (prefilters != null && prefilters.length > 0) {
218210
BitSet concatenatedFilters = concatenate(prefilters, cuvsQuery.getNumDocs());
219211
long[] filters = concatenatedFilters.toLongArray();
220-
prefilterDataMemorySegment = buildMemorySegment(arena, filters);
212+
prefilterDataMemorySegment = buildMemorySegment(localArena, filters);
221213
prefilterDataLength = (long) cuvsQuery.getNumDocs() * prefilters.length;
222214
}
223215

224-
MemorySegment querySeg = buildMemorySegment(arena, cuvsQuery.getQueryVectors());
216+
MemorySegment querySeg = buildMemorySegment(localArena, cuvsQuery.getQueryVectors());
225217

226218
int topk = cuvsQuery.getTopK();
227219
long cuvsResources = resources.getHandle();
228-
MemorySegment stream = arena.allocate(cudaStream_t);
229-
var returnValue = cuvsStreamGet(cuvsResources, stream);
230-
checkCuVSError(returnValue, "cuvsStreamGet");
231-
232-
MemorySegment queriesD = localArena.allocate(C_POINTER);
233-
MemorySegment neighborsD = localArena.allocate(C_POINTER);
234-
MemorySegment distancesD = localArena.allocate(C_POINTER);
235220

236221
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
237222
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
238223
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
239224
long prefilterBytes = 0; // size assigned later
240225

241-
returnValue = cuvsRMMAlloc(cuvsResources, queriesD, queriesBytes);
242-
checkCuVSError(returnValue, "cuvsRMMAlloc");
243-
returnValue = cuvsRMMAlloc(cuvsResources, neighborsD, neighborsBytes);
244-
checkCuVSError(returnValue, "cuvsRMMAlloc");
245-
returnValue = cuvsRMMAlloc(cuvsResources, distancesD, distanceBytes);
246-
checkCuVSError(returnValue, "cuvsRMMAlloc");
247-
248-
// IMPORTANT: these three should only come AFTER cuvsRMMAlloc calls
249-
MemorySegment queriesDP = queriesD.get(C_POINTER, 0);
250-
MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0);
251-
MemorySegment distancesDP = distancesD.get(C_POINTER, 0);
226+
MemorySegment queriesDP = allocateRMMSegment(cuvsResources, queriesBytes);
227+
MemorySegment neighborsDP = allocateRMMSegment(cuvsResources, neighborsBytes);
228+
MemorySegment distancesDP = allocateRMMSegment(cuvsResources, distanceBytes);
252229
MemorySegment prefilterDP = MemorySegment.NULL;
253230

254231
cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);
255232

256233
long[] queriesShape = {numQueries, vectorDimension};
257-
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
234+
MemorySegment queriesTensor =
235+
prepareTensor(localArena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
258236
long[] neighborsShape = {numQueries, topk};
259237
MemorySegment neighborsTensor =
260-
prepareTensor(arena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
238+
prepareTensor(localArena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
261239
long[] distancesShape = {numQueries, topk};
262240
MemorySegment distancesTensor =
263-
prepareTensor(arena, distancesDP, distancesShape, 2, 32, 2, 2, 1);
241+
prepareTensor(localArena, distancesDP, distancesShape, 2, 32, 2, 2, 1);
264242

265-
MemorySegment prefilter = cuvsFilter.allocate(arena);
243+
MemorySegment prefilter = cuvsFilter.allocate(localArena);
266244
MemorySegment prefilterTensor;
267245

268246
if (prefilterDataMemorySegment == MemorySegment.NULL) {
269247
cuvsFilter.type(prefilter, 0); // NO_FILTER
270248
cuvsFilter.addr(prefilter, 0);
271249
} else {
272250
long[] prefilterShape = {(prefilterDataLength + 31) / 32};
273-
274-
MemorySegment prefilterD = localArena.allocate(C_POINTER);
275251
long prefilterLen = prefilterShape[0];
276252
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;
277253

278-
returnValue = cuvsRMMAlloc(cuvsResources, prefilterD, prefilterBytes);
279-
checkCuVSError(returnValue, "cuvsRMMAlloc");
280-
prefilterDP = prefilterD.get(C_POINTER, 0);
254+
prefilterDP = allocateRMMSegment(cuvsResources, prefilterBytes);
281255

282256
cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
283257

284-
prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
258+
prefilterTensor = prepareTensor(localArena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
285259

286260
cuvsFilter.type(prefilter, 2);
287261
cuvsFilter.addr(prefilter, prefilterTensor.address());
288262
}
289263

290-
returnValue = cuvsStreamSync(cuvsResources);
264+
var returnValue = cuvsStreamSync(cuvsResources);
291265
checkCuVSError(returnValue, "cuvsStreamSync");
292266

293267
returnValue =
@@ -361,13 +335,12 @@ private static MemorySegment createBruteForceIndex() {
361335
try (var localArena = Arena.ofConfined()) {
362336
MemorySegment indexPtrPtr = localArena.allocate(cuvsBruteForceIndex_t);
363337
// cuvsBruteForceIndexCreate gets a pointer to a cuvsBruteForceIndex_t, which is defined as a
364-
// pointer to
365-
// cuvsBruteForceIndex.
366-
// It's basically a "out" parameter: the C functions will create the index and "return back" a
367-
// pointer to it.
338+
// pointer to cuvsBruteForceIndex.
339+
// It's basically an "out" parameter: the C functions will create the index and "return back"
340+
// a pointer to it.
368341
// The "out parameter" pointer is needed only for the duration of the function invocation (it
369-
// could be a stack
370-
// pointer, in C) so we allocate it from our localArena, unwrap it and return it.
342+
// could be a stack pointer, in C) so we allocate it from our localArena, unwrap it and return
343+
// it.
371344
var returnValue = cuvsBruteForceIndexCreate(indexPtrPtr);
372345
checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");
373346
return indexPtrPtr.get(cuvsBruteForceIndex_t, 0);
@@ -498,23 +471,31 @@ public BruteForceIndexImpl build() throws Throwable {
498471
}
499472

500473
/**
501-
* Holds the memory reference to a BRUTEFORCE index and its associated dataset
474+
* Holds the memory reference to a BRUTEFORCE index, its associated dataset, and the arena used to allocate
475+
* input data
502476
*/
503477
private static class IndexReference {
504478

505479
private final MemorySegment datasetPtr;
506480
private final long datasetBytes;
481+
private final Arena tensorDataArena;
507482
private final MemorySegment indexPtr;
508483

509-
private IndexReference(MemorySegment datasetPtr, long datasetBytes, MemorySegment indexPtr) {
484+
private IndexReference(
485+
MemorySegment datasetPtr,
486+
long datasetBytes,
487+
Arena tensorDataArena,
488+
MemorySegment indexPtr) {
510489
this.datasetPtr = datasetPtr;
511490
this.datasetBytes = datasetBytes;
491+
this.tensorDataArena = tensorDataArena;
512492
this.indexPtr = indexPtr;
513493
}
514494

515495
private IndexReference(MemorySegment indexPtr) {
516496
this.datasetPtr = MemorySegment.NULL;
517497
this.datasetBytes = 0;
498+
this.tensorDataArena = null;
518499
this.indexPtr = indexPtr;
519500
}
520501
}

0 commit comments

Comments
 (0)