Skip to content

Commit 6e0fe58

Browse files
authored
Improve ensure_memoryview test coverage & make minor fixes (#6333)
1 parent 50d2911 commit 6e0fe58

File tree

5 files changed

+74
-32
lines changed

5 files changed

+74
-32
lines changed

distributed/comm/asyncio_tcp.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
host_array,
2828
to_frames,
2929
)
30-
from distributed.utils import ensure_ip, get_ip, get_ipv6
30+
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6
3131

3232
logger = logging.getLogger(__name__)
3333

@@ -380,7 +380,9 @@ async def write(self, frames: list[bytes]) -> int:
380380
await drain_waiter
381381

382382
# Ensure all memoryviews are in single-byte format
383-
frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]
383+
frames = [
384+
ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames
385+
]
384386

385387
nframes = len(frames)
386388
frames_nbytes = [len(f) for f in frames]
@@ -852,12 +854,9 @@ def _buffer_clear(self):
852854

853855
def _buffer_append(self, data: bytes) -> None:
854856
"""Append new data to the send buffer"""
855-
if not isinstance(data, memoryview):
856-
data = memoryview(data)
857-
if data.format != "B":
858-
data = data.cast("B")
859-
self._size += len(data)
860-
self._buffers.append(data)
857+
mv = ensure_memoryview(data)
858+
self._size += len(mv)
859+
self._buffers.append(mv)
861860

862861
def _buffer_peek(self) -> list[memoryview]:
863862
"""Get one or more buffers to write to the socket"""

distributed/comm/tcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from distributed.protocol.utils import pack_frames_prelude, unpack_frames
4848
from distributed.system import MEMORY_LIMIT
49-
from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes
49+
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes
5050

5151
logger = logging.getLogger(__name__)
5252

@@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"):
305305
if isinstance(each_frame, memoryview):
306306
# Make sure that `len(data) == data.nbytes`
307307
# See <https://github.com/tornadoweb/tornado/pull/2996>
308-
each_frame = memoryview(each_frame).cast("B")
308+
each_frame = ensure_memoryview(each_frame)
309309

310310
stream._write_buffer.append(each_frame)
311311
stream._total_write_index += each_frame_nbytes

distributed/protocol/serialize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def pickle_loads(header, frames):
8888
writeable = len(buffers) * (None,)
8989

9090
new = []
91-
memoryviews = map(memoryview, buffers)
91+
memoryviews = map(ensure_memoryview, buffers)
9292
for w, mv in zip(writeable, memoryviews):
9393
if w == mv.readonly:
9494
if w:
@@ -785,7 +785,7 @@ def _serialize_memoryview(obj):
785785
@dask_deserialize.register(memoryview)
786786
def _deserialize_memoryview(header, frames):
787787
if len(frames) == 1:
788-
out = memoryview(frames[0]).cast("B")
788+
out = ensure_memoryview(frames[0])
789789
else:
790790
out = memoryview(b"".join(frames))
791791
out = out.cast(header["format"], header["shape"])

distributed/tests/test_utils.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -248,23 +248,59 @@ def test_seek_delimiter_endline():
248248
assert f.tell() == 7
249249

250250

251-
def test_ensure_memoryview_empty():
252-
result = ensure_memoryview(b"")
251+
@pytest.mark.parametrize(
252+
"data",
253+
[
254+
b"",
255+
bytearray(),
256+
b"1",
257+
bytearray(b"1"),
258+
memoryview(b"1"),
259+
memoryview(bytearray(b"1")),
260+
array("B", b"1"),
261+
array("I", range(5)),
262+
memoryview(b"123456")[::2],
263+
memoryview(b"123456").cast("B", (2, 3)),
264+
memoryview(b"0123456789").cast("B", (5, 2))[::2],
265+
],
266+
)
267+
def test_ensure_memoryview(data):
268+
data_mv = memoryview(data)
269+
result = ensure_memoryview(data)
253270
assert isinstance(result, memoryview)
254-
assert result == memoryview(b"")
255-
256-
257-
def test_ensure_memoryview():
258-
data = [b"1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")]
259-
for d in data:
260-
result = ensure_memoryview(d)
261-
assert isinstance(result, memoryview)
262-
assert result == memoryview(b"1")
263-
264-
265-
def test_ensure_memoryview_ndarray():
271+
assert result.contiguous
272+
assert result.ndim == 1
273+
assert result.format == "B"
274+
assert result == bytes(data_mv)
275+
if data_mv.nbytes and data_mv.contiguous:
276+
assert id(result.obj) == id(data_mv.obj)
277+
assert result.readonly == data_mv.readonly
278+
if isinstance(data, memoryview):
279+
if data.ndim == 1 and data.format == "B":
280+
assert id(result) == id(data)
281+
else:
282+
assert id(data) != id(result)
283+
else:
284+
assert id(result.obj) != id(data_mv.obj)
285+
assert not result.readonly
286+
287+
288+
@pytest.mark.parametrize(
289+
"dt, nitems, shape, strides",
290+
[
291+
("i8", 12, (12,), (8,)),
292+
("i8", 12, (3, 4), (32, 8)),
293+
("i8", 12, (4, 3), (8, 32)),
294+
("i8", 12, (3, 2), (32, 16)),
295+
("i8", 12, (2, 3), (16, 32)),
296+
],
297+
)
298+
def test_ensure_memoryview_ndarray(dt, nitems, shape, strides):
266299
np = pytest.importorskip("numpy")
267-
result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T)
300+
data = np.ndarray(
301+
shape, dtype=dt, buffer=np.arange(nitems, dtype=dt), strides=strides
302+
)
303+
result = ensure_memoryview(data)
268304
assert isinstance(result, memoryview)
269305
assert result.ndim == 1
270306
assert result.format == "B"

distributed/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from functools import wraps
2828
from hashlib import md5
2929
from importlib.util import cache_from_source
30+
from pickle import PickleBuffer
3031
from time import sleep
3132
from types import ModuleType
3233
from typing import TYPE_CHECKING
@@ -1021,13 +1022,19 @@ def ensure_memoryview(obj):
10211022

10221023
if not mv.nbytes:
10231024
# Drop `obj` reference to permit freeing underlying data
1024-
return memoryview(b"")
1025-
elif mv.contiguous:
1025+
return memoryview(bytearray())
1026+
elif not mv.contiguous:
1027+
# Copy to contiguous form of expected shape & type
1028+
return memoryview(bytearray(mv))
1029+
elif mv.ndim != 1 or mv.format != "B":
10261030
# Perform zero-copy reshape & cast
1027-
return mv.cast("B")
1031+
# Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order
1032+
# Pass `mv.obj` so the created `memoryview` has that as its `obj`
1033+
# xref: https://github.com/python/cpython/issues/91484
1034+
return PickleBuffer(mv.obj).raw()
10281035
else:
1029-
# Copy to contiguous form of expected shape & type
1030-
return memoryview(mv.tobytes())
1036+
# Return `memoryview` as it already meets requirements
1037+
return mv
10311038

10321039

10331040
def open_port(host=""):

0 commit comments

Comments
 (0)