Skip to content

Commit 2609748

Browse files
authored
Fix and test sharding with GPU buffers (#2978)
1 parent 693324c commit 2609748

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

changes/2978.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed sharding with GPU buffers.

src/zarr/codecs/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
683683
config=ArrayConfig(
684684
order="C", write_empty_chunks=False
685685
), # Note: this is hard-coded for simplicity -- it is not surfaced into user code,
686-
prototype=numpy_buffer_prototype(),
686+
prototype=default_buffer_prototype(),
687687
)
688688

689689
def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:

src/zarr/testing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def has_cupy() -> bool:
4444
# Decorator for GPU tests
4545
def gpu_test(func: T_Callable) -> T_Callable:
4646
return cast(
47-
T_Callable,
47+
"T_Callable",
4848
pytest.mark.gpu(
4949
pytest.mark.skipif(not has_cupy(), reason="CuPy not installed or no GPU available")(
5050
func

tests/test_buffer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,34 @@ async def test_codecs_use_of_gpu_prototype() -> None:
148148
assert cp.array_equal(expect, got)
149149

150150

151+
@gpu_test
152+
@pytest.mark.asyncio
153+
async def test_sharding_use_of_gpu_prototype() -> None:
154+
with zarr.config.enable_gpu():
155+
expect = cp.zeros((10, 10), dtype="uint16", order="F")
156+
157+
a = await zarr.api.asynchronous.create_array(
158+
StorePath(MemoryStore()) / "test_codecs_use_of_gpu_prototype",
159+
shape=expect.shape,
160+
chunks=(5, 5),
161+
shards=(10, 10),
162+
dtype=expect.dtype,
163+
fill_value=0,
164+
)
165+
expect[:] = cp.arange(100).reshape(10, 10)
166+
167+
await a.setitem(
168+
selection=(slice(0, 10), slice(0, 10)),
169+
value=expect[:],
170+
prototype=gpu.buffer_prototype,
171+
)
172+
got = await a.getitem(
173+
selection=(slice(0, 10), slice(0, 10)), prototype=gpu.buffer_prototype
174+
)
175+
assert isinstance(got, cp.ndarray)
176+
assert cp.array_equal(expect, got)
177+
178+
151179
def test_numpy_buffer_prototype() -> None:
152180
buffer = cpu.buffer_prototype.buffer.create_zero_length()
153181
ndbuffer = cpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64"))
@@ -157,6 +185,16 @@ def test_numpy_buffer_prototype() -> None:
157185
ndbuffer.as_scalar()
158186

159187

188+
@gpu_test
189+
def test_gpu_buffer_prototype() -> None:
190+
buffer = gpu.buffer_prototype.buffer.create_zero_length()
191+
ndbuffer = gpu.buffer_prototype.nd_buffer.create(shape=(1, 2), dtype=cp.dtype("int64"))
192+
assert isinstance(buffer.as_array_like(), cp.ndarray)
193+
assert isinstance(ndbuffer.as_ndarray_like(), cp.ndarray)
194+
with pytest.raises(ValueError, match="Buffer does not contain a single scalar value"):
195+
ndbuffer.as_scalar()
196+
197+
160198
# TODO: the same test for other buffer classes
161199
def test_cpu_buffer_as_scalar() -> None:
162200
buf = cpu.buffer_prototype.nd_buffer.create(shape=(), dtype="int64")

0 commit comments

Comments
 (0)