Skip to content

Commit 54ae46a

Browse files
authored
[Java] One PinnedMemoryBuffer per CuVSResourcesImpl (#1441)
While profiling cuvs-java, we found that allocating a new `PinnedMemoryBuffer` for each host->device or device->host memory copy was unnecessary and wasteful. This PR moves the allocation of the `PinnedMemoryBuffer` into `CuVSResourcesImpl`, so that a single buffer can be cached and reused. Since `CuVSResources` instances are already meant to be per-thread, this is safe: the `PinnedMemoryBuffer` will never be used concurrently. To do this cleanly, we introduced two named `ScopedAccess` classes (`DelegatingScopedAccess` and `ScopedAccessWithHostBuffer`) and a helper method (`CuVSResourcesImpl.getHostBuffer`) that always finds its way to the internal `MemorySegment` used by native functions to access the buffer, without the need to expose it via the public interface. Authors: - Lorenzo Dematté (https://github.com/ldematte) - Ben Frederickson (https://github.com/benfred) - MithunR (https://github.com/mythrocks) Approvers: - MithunR (https://github.com/mythrocks) URL: #1441
1 parent 94ea498 commit 54ae46a

File tree

11 files changed

+284
-179
lines changed

11 files changed

+284
-179
lines changed

java/benchmarks/src/main/java/com/nvidia/cuvs/CuVSDeviceMatrixBenchmarks.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.Random;
1212

1313
@Fork(value = 1, warmups = 0)
14+
@Threads(1) // Sharing resources
1415
@State(Scope.Benchmark)
1516
public class CuVSDeviceMatrixBenchmarks {
1617

@@ -82,15 +83,13 @@ public void matrixCopyDeviceToHost() {
8283

8384
@Benchmark
8485
public void matrixDeviceBuilder() throws Throwable {
85-
try (CuVSResources resources = CuVSResources.create()) {
86-
var builder = CuVSMatrix.deviceBuilder(resources, size, dims, CuVSMatrix.DataType.FLOAT);
86+
var builder = CuVSMatrix.deviceBuilder(resources, size, dims, CuVSMatrix.DataType.FLOAT);
8787

88-
for (int i = 0; i < size; ++i) {
89-
var array = data[i];
90-
builder.addVector(array);
91-
}
92-
CuVSDeviceMatrix matrix = builder.build();
93-
matrix.close();
88+
for (int i = 0; i < size; ++i) {
89+
var array = data[i];
90+
builder.addVector(array);
9491
}
92+
CuVSDeviceMatrix matrix = builder.build();
93+
matrix.close();
9594
}
9695
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,20 @@
1717
public interface CuVSMatrix extends AutoCloseable {
1818

1919
enum DataType {
20-
FLOAT,
21-
INT,
22-
UINT,
23-
BYTE
20+
FLOAT(4),
21+
INT(4),
22+
UINT(4),
23+
BYTE(1);
24+
25+
private final int bytes;
26+
27+
DataType(int bytes) {
28+
this.bytes = bytes;
29+
}
30+
31+
public int bytes() {
32+
return bytes;
33+
}
2434
}
2535

2636
enum MemoryKind {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package com.nvidia.cuvs;
6+
7+
public final class DelegatingScopedAccess implements CuVSResources.ScopedAccess {
8+
private final CuVSResources.ScopedAccess inner;
9+
private final Runnable closeAction;
10+
11+
DelegatingScopedAccess(CuVSResources.ScopedAccess inner, Runnable closeAction) {
12+
this.inner = inner;
13+
this.closeAction = closeAction;
14+
}
15+
16+
public CuVSResources.ScopedAccess inner() {
17+
return inner;
18+
}
19+
20+
@Override
21+
public long handle() {
22+
return inner.handle();
23+
}
24+
25+
@Override
26+
public void close() {
27+
closeAction.run();
28+
}
29+
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/SynchronizedCuVSResources.java

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,7 @@ static CuVSResources create() throws Throwable {
2727
@Override
2828
public ScopedAccess access() {
2929
lock.lock();
30-
return new ScopedAccess() {
31-
@Override
32-
public long handle() {
33-
return inner.access().handle();
34-
}
35-
36-
@Override
37-
public void close() {
38-
lock.unlock();
39-
}
40-
};
30+
return new DelegatingScopedAccess(inner.access(), lock::unlock);
4131
}
4232

4333
@Override

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSDeviceMatrixImpl.java

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,55 @@
1515

1616
public class CuVSDeviceMatrixImpl extends CuVSMatrixBaseImpl implements CuVSDeviceMatrix {
1717

18-
private long bufferedMatrixRowStart = 0;
19-
private long bufferedMatrixRowEnd = 0;
18+
private interface RowAccessStrategy {
19+
RowView getRow(long row);
20+
}
2021

2122
private final CuVSResources resources;
2223

2324
private final long rowStride;
2425
private final long columnStride;
26+
private final long rowSize;
27+
private final long valueByteSize;
28+
29+
private final RowAccessStrategy rowAccessStrategy;
2530

26-
private final PinnedMemoryBuffer hostBuffer;
31+
private class BufferedRowAccessStrategy implements RowAccessStrategy {
32+
private long bufferedMatrixRowStart = 0;
33+
private long bufferedMatrixRowEnd = 0;
34+
35+
@Override
36+
public RowView getRow(long row) {
37+
try (var access = resources.access()) {
38+
var hostBuffer = CuVSResourcesImpl.getHostBuffer(access);
39+
long rowBytes = columns * valueByteSize;
40+
if (row < bufferedMatrixRowStart || row >= bufferedMatrixRowEnd) {
41+
var endRow = Math.min(row + (PinnedMemoryBuffer.CHUNK_BYTES / rowBytes), size);
42+
populateBuffer(access, row, endRow, hostBuffer);
43+
bufferedMatrixRowStart = row;
44+
bufferedMatrixRowEnd = endRow;
45+
}
46+
var startRow = row - bufferedMatrixRowStart;
47+
return new SliceRowView(
48+
hostBuffer.asSlice(startRow * rowSize, rowBytes),
49+
columns,
50+
valueLayout,
51+
dataType,
52+
valueByteSize);
53+
}
54+
}
55+
}
56+
57+
private class DirectRowAccessStrategy implements RowAccessStrategy {
58+
@Override
59+
public RowView getRow(long row) {
60+
try (var access = resources.access()) {
61+
var memorySegment = Arena.ofAuto().allocate(size * valueByteSize);
62+
populateBuffer(access, row, row + 1, memorySegment);
63+
return new SliceRowView(memorySegment, columns, valueLayout, dataType, valueByteSize);
64+
}
65+
}
66+
}
2767

2868
protected CuVSDeviceMatrixImpl(
2969
CuVSResources resources,
@@ -48,7 +88,15 @@ protected CuVSDeviceMatrixImpl(
4888
this.resources = resources;
4989
this.rowStride = rowStride;
5090
this.columnStride = columnStride;
51-
this.hostBuffer = new PinnedMemoryBuffer(size, columns, valueLayout);
91+
92+
this.valueByteSize = valueLayout.byteSize();
93+
this.rowSize = rowStride > 0 ? rowStride * valueByteSize : columns * valueByteSize;
94+
if (rowSize > PinnedMemoryBuffer.CHUNK_BYTES) {
95+
// The shared buffer is too small for this row size, use a direct access strategy
96+
this.rowAccessStrategy = new DirectRowAccessStrategy();
97+
} else {
98+
this.rowAccessStrategy = new BufferedRowAccessStrategy();
99+
}
52100
}
53101

54102
@Override
@@ -63,15 +111,14 @@ public MemorySegment toTensor(Arena arena) {
63111
arena, memorySegment, new long[] {size, columns}, strides, code(), bits(), kDLCUDA());
64112
}
65113

66-
private void populateBuffer(long startRow) {
114+
private void populateBuffer(
115+
CuVSResources.ScopedAccess resourceAccess,
116+
long startRow,
117+
long endRow,
118+
MemorySegment bufferAddress) {
67119
try (var localArena = Arena.ofConfined()) {
68-
long rowBytes = columns * valueLayout.byteSize();
69-
var endRow = Math.min(startRow + (hostBuffer.size() / rowBytes), size);
70120
var rowCount = endRow - startRow;
71121

72-
// System.out.printf(
73-
// Locale.ROOT, "startRow: %d, endRow %d, count: %d\n", startRow, endRow, rowCount);
74-
75122
MemorySegment sliceManagedTensor = DLManagedTensor.allocate(localArena);
76123
DLManagedTensor.dl_tensor(sliceManagedTensor, DLTensor.allocate(localArena));
77124

@@ -85,40 +132,18 @@ private void populateBuffer(long startRow) {
85132

86133
MemorySegment bufferTensor =
87134
prepareTensor(
88-
localArena,
89-
hostBuffer.address(),
90-
new long[] {rowCount, columns},
91-
code(),
92-
bits(),
93-
kDLCPU());
94-
95-
try (var resourceAccess = resources.access()) {
96-
checkCuVSError(
97-
cuvsMatrixCopy(resourceAccess.handle(), sliceManagedTensor, bufferTensor),
98-
"cuvsMatrixCopy");
99-
checkCuVSError(cuvsStreamSync(resourceAccess.handle()), "cuvsStreamSync");
100-
101-
bufferedMatrixRowStart = startRow;
102-
bufferedMatrixRowEnd = endRow;
103-
}
135+
localArena, bufferAddress, new long[] {rowCount, columns}, code(), bits(), kDLCPU());
136+
137+
checkCuVSError(
138+
cuvsMatrixCopy(resourceAccess.handle(), sliceManagedTensor, bufferTensor),
139+
"cuvsMatrixCopy");
140+
checkCuVSError(cuvsStreamSync(resourceAccess.handle()), "cuvsStreamSync");
104141
}
105142
}
106143

107144
@Override
108145
public RowView getRow(long row) {
109-
if (row < bufferedMatrixRowStart || row >= bufferedMatrixRowEnd) {
110-
populateBuffer(row);
111-
}
112-
var valueByteSize = valueLayout.byteSize();
113-
var startRow = row - bufferedMatrixRowStart;
114-
115-
var rowSize = rowStride > 0 ? rowStride * valueByteSize : columns * valueByteSize;
116-
return new SliceRowView(
117-
hostBuffer.address().asSlice(startRow * rowSize, columns * valueByteSize),
118-
columns,
119-
valueLayout,
120-
dataType,
121-
valueByteSize);
146+
return rowAccessStrategy.getRow(row);
122147
}
123148

124149
@Override
@@ -199,9 +224,7 @@ public void toDevice(CuVSDeviceMatrix targetMatrix, CuVSResources cuVSResources)
199224
}
200225

201226
@Override
202-
public void close() {
203-
hostBuffer.close();
204-
}
227+
public void close() {}
205228

206229
private static class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix, CuVSMatrixInternal {
207230
private final CuVSDeviceMatrixImpl deviceMatrix;

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import static com.nvidia.cuvs.internal.panama.headers_h_1.C_INT;
1010

1111
import com.nvidia.cuvs.CuVSResources;
12+
import com.nvidia.cuvs.DelegatingScopedAccess;
13+
import com.nvidia.cuvs.internal.common.PinnedMemoryBuffer;
1214
import java.lang.foreign.Arena;
15+
import java.lang.foreign.MemorySegment;
1316
import java.nio.file.Path;
1417

1518
/**
@@ -24,6 +27,8 @@ public class CuVSResourcesImpl implements CuVSResources {
2427
private final ScopedAccess access;
2528
private final int deviceId;
2629

30+
private final PinnedMemoryBuffer hostBuffer = new PinnedMemoryBuffer();
31+
2732
/**
2833
* Constructor that allocates the resources needed for cuVS
2934
*
@@ -37,16 +42,7 @@ public CuVSResourcesImpl(Path tempDirectory) {
3742
var deviceIdPtr = localArena.allocate(C_INT);
3843
checkCuVSError(cuvsDeviceIdGet(resourceHandle, deviceIdPtr), "cuvsDeviceIdGet");
3944
this.deviceId = deviceIdPtr.get(C_INT, 0);
40-
this.access =
41-
new ScopedAccess() {
42-
@Override
43-
public long handle() {
44-
return resourceHandle;
45-
}
46-
47-
@Override
48-
public void close() {}
49-
};
45+
this.access = new ScopedAccessWithHostBuffer(resourceHandle, hostBuffer.address());
5046
}
5147
}
5248

@@ -65,11 +61,25 @@ public void close() {
6561
synchronized (this) {
6662
int returnValue = cuvsResourcesDestroy(resourceHandle);
6763
checkCuVSError(returnValue, "cuvsResourcesDestroy");
64+
hostBuffer.close();
6865
}
6966
}
7067

7168
@Override
7269
public Path tempDirectory() {
7370
return tempDirectory;
7471
}
72+
73+
public static MemorySegment getHostBuffer(ScopedAccess access) {
74+
75+
while (access instanceof DelegatingScopedAccess delegatingScopedAccess) {
76+
access = delegatingScopedAccess.inner();
77+
}
78+
79+
if (access instanceof ScopedAccessWithHostBuffer withHostBuffer) {
80+
return withHostBuffer.hostBuffer();
81+
}
82+
83+
throw new IllegalArgumentException("Unsupported access type: " + access.getClass().getName());
84+
}
7585
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package com.nvidia.cuvs.internal;
6+
7+
import com.nvidia.cuvs.CuVSResources;
8+
import java.lang.foreign.MemorySegment;
9+
10+
class ScopedAccessWithHostBuffer implements CuVSResources.ScopedAccess {
11+
private final long resourceHandle;
12+
private final MemorySegment hostBuffer;
13+
14+
public ScopedAccessWithHostBuffer(long resourceHandle, MemorySegment hostBuffer) {
15+
this.resourceHandle = resourceHandle;
16+
this.hostBuffer = hostBuffer;
17+
}
18+
19+
@Override
20+
public long handle() {
21+
return resourceHandle;
22+
}
23+
24+
public MemorySegment hostBuffer() {
25+
return hostBuffer;
26+
}
27+
28+
@Override
29+
public void close() {}
30+
}

0 commit comments

Comments
 (0)