Skip to content

Commit 0dc8f93

Browse files
committed
test_cupy: test when CuPy isn't available
1 parent 440753c commit 0dc8f93

File tree

1 file changed

+80
-40
lines changed

1 file changed

+80
-40
lines changed

zarr/tests/test_cupy.py

Lines changed: 80 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
from zarr.storage import DirectoryStore, MemoryStore, Store, ZipStore
1414

1515

16-
cupy = pytest.importorskip("cupy")
17-
18-
1916
class CuPyCPUCompressor(Codec):
2017
"""CPU compressor for CuPy arrays
2118
@@ -71,6 +68,28 @@ def from_config(cls, config):
7168
register_codec(CuPyCPUCompressor)
7269

7370

71+
class MyArray(np.ndarray):
72+
"""Dummy array class to test the `meta_array` argument
73+
74+
Useful when CuPy isn't available.
75+
76+
This class also makes some of the functions from the numpy
77+
module available.
78+
"""
79+
80+
testing = np.testing
81+
82+
@classmethod
83+
def arange(cls, size):
84+
ret = cls(shape=(size,), dtype="int64")
85+
ret[:] = range(size)
86+
return ret
87+
88+
@classmethod
89+
def empty(cls, shape):
90+
return cls(shape=shape)
91+
92+
7493
def init_compressor(compressor) -> CuPyCPUCompressor:
7594
if compressor:
7695
compressor = getattr(zarr.codecs, compressor)()
@@ -85,93 +104,113 @@ def init_store(tmp_path, store_type) -> Optional[Store]:
85104
return None
86105

87106

88-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
107+
def ensure_cls(obj):
108+
if isinstance(obj, str):
109+
module, cls_name = obj.rsplit(".", maxsplit=1)
110+
return getattr(pytest.importorskip(module), cls_name)
111+
return obj
112+
113+
114+
def ensure_module(module):
115+
if isinstance(module, str):
116+
return pytest.importorskip(module)
117+
return module
118+
119+
120+
param_module_and_compressor = [
121+
(MyArray, None),
122+
("cupy", init_compressor(None)),
123+
("cupy", init_compressor("Zlib")),
124+
("cupy", init_compressor("Blosc")),
125+
]
126+
127+
128+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
89129
@pytest.mark.parametrize("store_type", [None, DirectoryStore, MemoryStore, ZipStore])
90-
def test_array(tmp_path, compressor, store_type):
91-
compressor = init_compressor(compressor)
130+
def test_array(tmp_path, module, compressor, store_type):
131+
xp = ensure_module(module)
92132

93-
# with cupy array
94133
store = init_store(tmp_path / "from_cupy_array", store_type)
95-
a = cupy.arange(100)
96-
z = array(
97-
a, chunks=10, compressor=compressor, store=store, meta_array=cupy.empty(())
98-
)
134+
a = xp.arange(100)
135+
z = array(a, chunks=10, compressor=compressor, store=store, meta_array=xp.empty(()))
99136
assert a.shape == z.shape
100137
assert a.dtype == z.dtype
101138
assert isinstance(a, type(z[:]))
102-
assert isinstance(z.meta_array, type(cupy.empty(())))
103-
cupy.testing.assert_array_equal(a, z[:])
139+
assert isinstance(z.meta_array, type(xp.empty(())))
140+
xp.testing.assert_array_equal(a, z[:])
104141

105142
# with array-like
106143
store = init_store(tmp_path / "from_list", store_type)
107144
a = list(range(100))
108-
z = array(
109-
a, chunks=10, compressor=compressor, store=store, meta_array=cupy.empty(())
110-
)
145+
z = array(a, chunks=10, compressor=compressor, store=store, meta_array=xp.empty(()))
111146
assert (100,) == z.shape
112147
assert np.asarray(a).dtype == z.dtype
113-
cupy.testing.assert_array_equal(a, z[:])
148+
xp.testing.assert_array_equal(a, z[:])
114149

115150
# with another zarr array
116151
store = init_store(tmp_path / "from_another_store", store_type)
117-
z2 = array(z, compressor=compressor, store=store, meta_array=cupy.empty(()))
152+
z2 = array(z, compressor=compressor, store=store, meta_array=xp.empty(()))
118153
assert z.shape == z2.shape
119154
assert z.chunks == z2.chunks
120155
assert z.dtype == z2.dtype
121-
cupy.testing.assert_array_equal(z[:], z2[:])
156+
xp.testing.assert_array_equal(z[:], z2[:])
122157

123158

124-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
125-
def test_empty(compressor):
159+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
160+
def test_empty(module, compressor):
161+
xp = ensure_module(module)
126162
z = empty(
127163
100,
128164
chunks=10,
129165
compressor=init_compressor(compressor),
130-
meta_array=cupy.empty(()),
166+
meta_array=xp.empty(()),
131167
)
132168
assert (100,) == z.shape
133169
assert (10,) == z.chunks
134170

135171

136-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
137-
def test_zeros(compressor):
172+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
173+
def test_zeros(module, compressor):
174+
xp = ensure_module(module)
138175
z = zeros(
139176
100,
140177
chunks=10,
141178
compressor=init_compressor(compressor),
142-
meta_array=cupy.empty(()),
179+
meta_array=xp.empty(()),
143180
)
144181
assert (100,) == z.shape
145182
assert (10,) == z.chunks
146-
cupy.testing.assert_array_equal(np.zeros(100), z[:])
183+
xp.testing.assert_array_equal(np.zeros(100), z[:])
147184

148185

149-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
150-
def test_ones(compressor):
186+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
187+
def test_ones(module, compressor):
188+
xp = ensure_module(module)
151189
z = ones(
152190
100,
153191
chunks=10,
154192
compressor=init_compressor(compressor),
155-
meta_array=cupy.empty(()),
193+
meta_array=xp.empty(()),
156194
)
157195
assert (100,) == z.shape
158196
assert (10,) == z.chunks
159-
cupy.testing.assert_array_equal(np.ones(100), z[:])
197+
xp.testing.assert_array_equal(np.ones(100), z[:])
160198

161199

162-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
163-
def test_full(compressor):
200+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
201+
def test_full(module, compressor):
202+
xp = ensure_module(module)
164203
z = full(
165204
100,
166205
chunks=10,
167206
fill_value=42,
168207
dtype="i4",
169208
compressor=init_compressor(compressor),
170-
meta_array=cupy.empty(()),
209+
meta_array=xp.empty(()),
171210
)
172211
assert (100,) == z.shape
173212
assert (10,) == z.chunks
174-
cupy.testing.assert_array_equal(np.full(100, fill_value=42, dtype="i4"), z[:])
213+
xp.testing.assert_array_equal(np.full(100, fill_value=42, dtype="i4"), z[:])
175214

176215
# nan
177216
z = full(
@@ -180,21 +219,22 @@ def test_full(compressor):
180219
fill_value=np.nan,
181220
dtype="f8",
182221
compressor=init_compressor(compressor),
183-
meta_array=cupy.empty(()),
222+
meta_array=xp.empty(()),
184223
)
185224
assert np.all(np.isnan(z[:]))
186225

187226

188-
@pytest.mark.parametrize("compressor", [None, "Zlib", "Blosc"])
227+
@pytest.mark.parametrize("module, compressor", param_module_and_compressor)
189228
@pytest.mark.parametrize("store_type", [None, DirectoryStore, MemoryStore, ZipStore])
190-
def test_group(tmp_path, compressor, store_type):
229+
def test_group(tmp_path, module, compressor, store_type):
230+
xp = ensure_module(module)
191231
store = init_store(tmp_path, store_type)
192-
g = open_group(store, meta_array=cupy.empty(()))
232+
g = open_group(store, meta_array=xp.empty(()))
193233
g.ones("data", shape=(10, 11), dtype=int, compressor=init_compressor(compressor))
194234
a = g["data"]
195235
assert a.shape == (10, 11)
196236
assert a.dtype == int
197237
assert isinstance(a, Array)
198-
assert isinstance(a[:], cupy.ndarray)
238+
assert isinstance(a[:], xp.ndarray)
199239
assert (a[:] == 1).all()
200-
assert isinstance(g.meta_array, type(cupy.empty(())))
240+
assert isinstance(g.meta_array, type(xp.empty(())))

0 commit comments

Comments
 (0)