20
20
import static com .nvidia .cuvs .internal .common .LinkerHelper .C_INT_BYTE_SIZE ;
21
21
import static com .nvidia .cuvs .internal .common .LinkerHelper .C_LONG ;
22
22
import static com .nvidia .cuvs .internal .common .LinkerHelper .C_LONG_BYTE_SIZE ;
23
- import static com .nvidia .cuvs .internal .common .LinkerHelper .C_POINTER ;
24
- import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .*;
23
+ import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .HOST_TO_DEVICE ;
24
+ import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .INFER_DIRECTION ;
25
+ import static com .nvidia .cuvs .internal .common .Util .allocateRMMSegment ;
25
26
import static com .nvidia .cuvs .internal .common .Util .buildMemorySegment ;
26
27
import static com .nvidia .cuvs .internal .common .Util .checkCuVSError ;
27
28
import static com .nvidia .cuvs .internal .common .Util .concatenate ;
28
29
import static com .nvidia .cuvs .internal .common .Util .cudaMemcpy ;
29
30
import static com .nvidia .cuvs .internal .common .Util .prepareTensor ;
30
- import static com .nvidia .cuvs .internal .panama .headers_h .cudaStream_t ;
31
31
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceBuild ;
32
32
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceDeserialize ;
33
33
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceIndexCreate ;
34
34
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceIndexDestroy ;
35
35
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceIndex_t ;
36
36
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceSearch ;
37
37
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceSerialize ;
38
- import static com .nvidia .cuvs .internal .panama .headers_h .cuvsRMMAlloc ;
39
38
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsRMMFree ;
40
- import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamGet ;
41
39
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamSync ;
42
40
import static com .nvidia .cuvs .internal .panama .headers_h .omp_set_num_threads ;
43
41
@@ -130,6 +128,9 @@ public void destroyIndex() {
130
128
bruteForceIndexReference .datasetBytes );
131
129
checkCuVSError (returnValue , "cuvsRMMFree" );
132
130
}
131
+ if (bruteForceIndexReference .tensorDataArena != null ) {
132
+ bruteForceIndexReference .tensorDataArena .close ();
133
+ }
133
134
} finally {
134
135
destroyed = true ;
135
136
}
@@ -143,49 +144,41 @@ public void destroyIndex() {
143
144
* index
144
145
*/
145
146
private IndexReference build (DatasetImpl dataset , BruteForceIndexParams bruteForceIndexParams ) {
146
- try (var localArena = Arena .ofConfined ()) {
147
- long rows = dataset .size ();
148
- long cols = dataset .dimensions ();
149
-
150
- Arena arena = resources .getArena ();
151
- MemorySegment datasetMemSegment = dataset .asMemorySegment ();
152
-
153
- long cuvsResources = resources .getHandle ();
154
-
155
- omp_set_num_threads (bruteForceIndexParams .getNumWriterThreads ());
147
+ long rows = dataset .size ();
148
+ long cols = dataset .dimensions ();
156
149
157
- MemorySegment datasetMemorySegment = localArena . allocate ( C_POINTER );
150
+ MemorySegment datasetMemSegment = dataset . asMemorySegment ( );
158
151
159
- long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols ;
160
- var returnValue = cuvsRMMAlloc (cuvsResources , datasetMemorySegment , datasetBytes );
161
- checkCuVSError (returnValue , "cuvsRMMAlloc" );
152
+ long cuvsResources = resources .getHandle ();
162
153
163
- // IMPORTANT: this should only come AFTER cuvsRMMAlloc call
164
- MemorySegment datasetMemorySegmentP = datasetMemorySegment .get (C_POINTER , 0 );
154
+ omp_set_num_threads (bruteForceIndexParams .getNumWriterThreads ());
155
+ long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols ;
156
+ MemorySegment datasetMemorySegmentP = allocateRMMSegment (cuvsResources , datasetBytes );
165
157
166
- cudaMemcpy (datasetMemorySegmentP , datasetMemSegment , datasetBytes , INFER_DIRECTION );
158
+ cudaMemcpy (datasetMemorySegmentP , datasetMemSegment , datasetBytes , INFER_DIRECTION );
167
159
168
- long [] datasetShape = {rows , cols };
169
- MemorySegment datasetTensor =
170
- prepareTensor (arena , datasetMemorySegmentP , datasetShape , 2 , 32 , 2 , 2 , 1 );
160
+ long [] datasetShape = {rows , cols };
161
+ var tensorDataArena = Arena .ofShared ();
162
+ MemorySegment datasetTensor =
163
+ prepareTensor (tensorDataArena , datasetMemorySegmentP , datasetShape , 2 , 32 , 2 , 2 , 1 );
171
164
172
- var indexReference =
173
- new IndexReference (datasetMemorySegmentP , datasetBytes , createBruteForceIndex ());
165
+ var indexReference =
166
+ new IndexReference (
167
+ datasetMemorySegmentP , datasetBytes , tensorDataArena , createBruteForceIndex ());
174
168
175
- returnValue = cuvsStreamSync (cuvsResources );
176
- checkCuVSError (returnValue , "cuvsStreamSync" );
169
+ var returnValue = cuvsStreamSync (cuvsResources );
170
+ checkCuVSError (returnValue , "cuvsStreamSync" );
177
171
178
- returnValue =
179
- cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , indexReference .indexPtr );
180
- checkCuVSError (returnValue , "cuvsBruteForceBuild" );
172
+ returnValue =
173
+ cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , indexReference .indexPtr );
174
+ checkCuVSError (returnValue , "cuvsBruteForceBuild" );
181
175
182
- returnValue = cuvsStreamSync (cuvsResources );
183
- checkCuVSError (returnValue , "cuvsStreamSync" );
176
+ returnValue = cuvsStreamSync (cuvsResources );
177
+ checkCuVSError (returnValue , "cuvsStreamSync" );
184
178
185
- omp_set_num_threads (1 );
179
+ omp_set_num_threads (1 );
186
180
187
- return indexReference ;
188
- }
181
+ return indexReference ;
189
182
}
190
183
191
184
/**
@@ -203,12 +196,11 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
203
196
long numQueries = cuvsQuery .getQueryVectors ().length ;
204
197
long numBlocks = cuvsQuery .getTopK () * numQueries ;
205
198
int vectorDimension = numQueries > 0 ? cuvsQuery .getQueryVectors ()[0 ].length : 0 ;
206
- Arena arena = resources .getArena ();
207
199
208
200
SequenceLayout neighborsSequenceLayout = MemoryLayout .sequenceLayout (numBlocks , C_LONG );
209
201
SequenceLayout distancesSequenceLayout = MemoryLayout .sequenceLayout (numBlocks , C_FLOAT );
210
- MemorySegment neighborsMemorySegment = arena .allocate (neighborsSequenceLayout );
211
- MemorySegment distancesMemorySegment = arena .allocate (distancesSequenceLayout );
202
+ MemorySegment neighborsMemorySegment = localArena .allocate (neighborsSequenceLayout );
203
+ MemorySegment distancesMemorySegment = localArena .allocate (distancesSequenceLayout );
212
204
213
205
// prepare the prefiltering data
214
206
long prefilterDataLength = 0 ;
@@ -217,77 +209,59 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
217
209
if (prefilters != null && prefilters .length > 0 ) {
218
210
BitSet concatenatedFilters = concatenate (prefilters , cuvsQuery .getNumDocs ());
219
211
long [] filters = concatenatedFilters .toLongArray ();
220
- prefilterDataMemorySegment = buildMemorySegment (arena , filters );
212
+ prefilterDataMemorySegment = buildMemorySegment (localArena , filters );
221
213
prefilterDataLength = (long ) cuvsQuery .getNumDocs () * prefilters .length ;
222
214
}
223
215
224
- MemorySegment querySeg = buildMemorySegment (arena , cuvsQuery .getQueryVectors ());
216
+ MemorySegment querySeg = buildMemorySegment (localArena , cuvsQuery .getQueryVectors ());
225
217
226
218
int topk = cuvsQuery .getTopK ();
227
219
long cuvsResources = resources .getHandle ();
228
- MemorySegment stream = arena .allocate (cudaStream_t );
229
- var returnValue = cuvsStreamGet (cuvsResources , stream );
230
- checkCuVSError (returnValue , "cuvsStreamGet" );
231
-
232
- MemorySegment queriesD = localArena .allocate (C_POINTER );
233
- MemorySegment neighborsD = localArena .allocate (C_POINTER );
234
- MemorySegment distancesD = localArena .allocate (C_POINTER );
235
220
236
221
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension ;
237
222
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk ;
238
223
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk ;
239
224
long prefilterBytes = 0 ; // size assigned later
240
225
241
- returnValue = cuvsRMMAlloc (cuvsResources , queriesD , queriesBytes );
242
- checkCuVSError (returnValue , "cuvsRMMAlloc" );
243
- returnValue = cuvsRMMAlloc (cuvsResources , neighborsD , neighborsBytes );
244
- checkCuVSError (returnValue , "cuvsRMMAlloc" );
245
- returnValue = cuvsRMMAlloc (cuvsResources , distancesD , distanceBytes );
246
- checkCuVSError (returnValue , "cuvsRMMAlloc" );
247
-
248
- // IMPORTANT: these three should only come AFTER cuvsRMMAlloc calls
249
- MemorySegment queriesDP = queriesD .get (C_POINTER , 0 );
250
- MemorySegment neighborsDP = neighborsD .get (C_POINTER , 0 );
251
- MemorySegment distancesDP = distancesD .get (C_POINTER , 0 );
226
+ MemorySegment queriesDP = allocateRMMSegment (cuvsResources , queriesBytes );
227
+ MemorySegment neighborsDP = allocateRMMSegment (cuvsResources , neighborsBytes );
228
+ MemorySegment distancesDP = allocateRMMSegment (cuvsResources , distanceBytes );
252
229
MemorySegment prefilterDP = MemorySegment .NULL ;
253
230
254
231
cudaMemcpy (queriesDP , querySeg , queriesBytes , INFER_DIRECTION );
255
232
256
233
long [] queriesShape = {numQueries , vectorDimension };
257
- MemorySegment queriesTensor = prepareTensor (arena , queriesDP , queriesShape , 2 , 32 , 2 , 2 , 1 );
234
+ MemorySegment queriesTensor =
235
+ prepareTensor (localArena , queriesDP , queriesShape , 2 , 32 , 2 , 2 , 1 );
258
236
long [] neighborsShape = {numQueries , topk };
259
237
MemorySegment neighborsTensor =
260
- prepareTensor (arena , neighborsDP , neighborsShape , 0 , 64 , 2 , 2 , 1 );
238
+ prepareTensor (localArena , neighborsDP , neighborsShape , 0 , 64 , 2 , 2 , 1 );
261
239
long [] distancesShape = {numQueries , topk };
262
240
MemorySegment distancesTensor =
263
- prepareTensor (arena , distancesDP , distancesShape , 2 , 32 , 2 , 2 , 1 );
241
+ prepareTensor (localArena , distancesDP , distancesShape , 2 , 32 , 2 , 2 , 1 );
264
242
265
- MemorySegment prefilter = cuvsFilter .allocate (arena );
243
+ MemorySegment prefilter = cuvsFilter .allocate (localArena );
266
244
MemorySegment prefilterTensor ;
267
245
268
246
if (prefilterDataMemorySegment == MemorySegment .NULL ) {
269
247
cuvsFilter .type (prefilter , 0 ); // NO_FILTER
270
248
cuvsFilter .addr (prefilter , 0 );
271
249
} else {
272
250
long [] prefilterShape = {(prefilterDataLength + 31 ) / 32 };
273
-
274
- MemorySegment prefilterD = localArena .allocate (C_POINTER );
275
251
long prefilterLen = prefilterShape [0 ];
276
252
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen ;
277
253
278
- returnValue = cuvsRMMAlloc (cuvsResources , prefilterD , prefilterBytes );
279
- checkCuVSError (returnValue , "cuvsRMMAlloc" );
280
- prefilterDP = prefilterD .get (C_POINTER , 0 );
254
+ prefilterDP = allocateRMMSegment (cuvsResources , prefilterBytes );
281
255
282
256
cudaMemcpy (prefilterDP , prefilterDataMemorySegment , prefilterBytes , HOST_TO_DEVICE );
283
257
284
- prefilterTensor = prepareTensor (arena , prefilterDP , prefilterShape , 1 , 32 , 1 , 2 , 1 );
258
+ prefilterTensor = prepareTensor (localArena , prefilterDP , prefilterShape , 1 , 32 , 1 , 2 , 1 );
285
259
286
260
cuvsFilter .type (prefilter , 2 );
287
261
cuvsFilter .addr (prefilter , prefilterTensor .address ());
288
262
}
289
263
290
- returnValue = cuvsStreamSync (cuvsResources );
264
+ var returnValue = cuvsStreamSync (cuvsResources );
291
265
checkCuVSError (returnValue , "cuvsStreamSync" );
292
266
293
267
returnValue =
@@ -361,13 +335,12 @@ private static MemorySegment createBruteForceIndex() {
361
335
try (var localArena = Arena .ofConfined ()) {
362
336
MemorySegment indexPtrPtr = localArena .allocate (cuvsBruteForceIndex_t );
363
337
// cuvsBruteForceIndexCreate gets a pointer to a cuvsBruteForceIndex_t, which is defined as a
364
- // pointer to
365
- // cuvsBruteForceIndex.
366
- // It's basically a "out" parameter: the C functions will create the index and "return back" a
367
- // pointer to it.
338
+ // pointer to cuvsBruteForceIndex.
339
+ // It's basically an "out" parameter: the C functions will create the index and "return back"
340
+ // a pointer to it.
368
341
// The "out parameter" pointer is needed only for the duration of the function invocation (it
369
- // could be a stack
370
- // pointer, in C) so we allocate it from our localArena, unwrap it and return it.
342
+ // could be a stack pointer, in C) so we allocate it from our localArena, unwrap it and return
343
+ // it.
371
344
var returnValue = cuvsBruteForceIndexCreate (indexPtrPtr );
372
345
checkCuVSError (returnValue , "cuvsBruteForceIndexCreate" );
373
346
return indexPtrPtr .get (cuvsBruteForceIndex_t , 0 );
@@ -498,23 +471,31 @@ public BruteForceIndexImpl build() throws Throwable {
498
471
}
499
472
500
473
/**
501
- * Holds the memory reference to a BRUTEFORCE index and its associated dataset
474
+ * Holds the memory reference to a BRUTEFORCE index, its associated dataset, and the arena used to allocate
475
+ * input data
502
476
*/
503
477
private static class IndexReference {
504
478
505
479
private final MemorySegment datasetPtr ;
506
480
private final long datasetBytes ;
481
+ private final Arena tensorDataArena ;
507
482
private final MemorySegment indexPtr ;
508
483
509
- private IndexReference (MemorySegment datasetPtr , long datasetBytes , MemorySegment indexPtr ) {
484
+ private IndexReference (
485
+ MemorySegment datasetPtr ,
486
+ long datasetBytes ,
487
+ Arena tensorDataArena ,
488
+ MemorySegment indexPtr ) {
510
489
this .datasetPtr = datasetPtr ;
511
490
this .datasetBytes = datasetBytes ;
491
+ this .tensorDataArena = tensorDataArena ;
512
492
this .indexPtr = indexPtr ;
513
493
}
514
494
515
495
private IndexReference (MemorySegment indexPtr ) {
516
496
this .datasetPtr = MemorySegment .NULL ;
517
497
this .datasetBytes = 0 ;
498
+ this .tensorDataArena = null ;
518
499
this .indexPtr = indexPtr ;
519
500
}
520
501
}
0 commit comments