Skip to content

Commit 60e71ee

Browse files
pickle SVM, centralize queue_for_pickling, use Buffer/SVM pickling for Arrays
1 parent 490bb94 commit 60e71ee

File tree

3 files changed

+62
-63
lines changed

3 files changed

+62
-63
lines changed

pyopencl/__init__.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,9 +2414,10 @@ def fsvm_empty_like(ctx, ary, alignment=None):
24142414

24152415

24162416
@contextmanager
2417-
def queue_for_pickling(queue):
2417+
def queue_for_pickling(queue, alloc=None):
24182418
r"""A context manager that, for the current thread, sets the command queue
2419-
to be used for pickling and unpickling :class:`Buffer`\ s to *queue*."""
2419+
to be used for pickling and unpickling :class:`Array`\ s and :class:`Buffer`\ s
2420+
to *queue*."""
24202421
try:
24212422
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
24222423
except AttributeError:
@@ -2427,10 +2428,12 @@ def queue_for_pickling(queue):
24272428
"inside the context of its own invocation.")
24282429

24292430
_QUEUE_FOR_PICKLING_TLS.queue = queue
2431+
_QUEUE_FOR_PICKLING_TLS.alloc = alloc
24302432
try:
24312433
yield None
24322434
finally:
24332435
_QUEUE_FOR_PICKLING_TLS.queue = None
2436+
_QUEUE_FOR_PICKLING_TLS.alloc = None
24342437

24352438

24362439
def _getstate_buffer(self):
@@ -2478,6 +2481,51 @@ def _setstate_buffer(self, state):
24782481
Buffer.__getstate__ = _getstate_buffer
24792482
Buffer.__setstate__ = _setstate_buffer
24802483

2484+
if get_cl_header_version() >= (2, 0):
2485+
def _getstate_svm(self):
2486+
import pyopencl as cl
2487+
2488+
state = {}
2489+
state["size"] = self.size
2490+
2491+
try:
2492+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2493+
except AttributeError:
2494+
queue = None
2495+
2496+
if queue is None:
2497+
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2498+
"pickled while queue_for_pickling is active.")
2499+
2500+
a = bytearray(self.size)
2501+
cl.enqueue_copy(queue, a, self)
2502+
2503+
state["_pickle_data"] = a
2504+
2505+
return state
2506+
2507+
def _setstate_svm(self, state):
2508+
import pyopencl as cl
2509+
2510+
try:
2511+
queue = _QUEUE_FOR_PICKLING_TLS.queue
2512+
except AttributeError:
2513+
queue = None
2514+
2515+
if queue is None:
2516+
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2517+
"unpickled while queue_for_pickling is active.")
2518+
2519+
size = state["size"]
2520+
2521+
a = state["_pickle_data"]
2522+
SVMAllocation.__init__(self, queue.context, size, alignment=0, flags=0,
2523+
queue=queue)
2524+
cl.enqueue_copy(queue, self, a)
2525+
2526+
SVMAllocation.__getstate__ = _getstate_svm
2527+
SVMAllocation.__setstate__ = _setstate_svm
2528+
24812529
# }}}
24822530

24832531
# vim: foldmethod=marker

pyopencl/array.py

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -342,39 +342,6 @@ class _copy_queue: # noqa: N801
342342
_NOT_PRESENT = object()
343343

344344

345-
# {{{ pickling support
346-
347-
import threading
348-
from contextlib import contextmanager
349-
350-
351-
_QUEUE_FOR_PICKLING_TLS = threading.local()
352-
353-
354-
@contextmanager
355-
def queue_for_pickling(queue, alloc=None):
356-
r"""A context manager that, for the current thread, sets the command queue
357-
to be used for pickling and unpickling :class:`Array`\ s to *queue*."""
358-
try:
359-
existing_pickle_queue = _QUEUE_FOR_PICKLING_TLS.queue
360-
except AttributeError:
361-
existing_pickle_queue = None
362-
363-
if existing_pickle_queue is not None:
364-
raise RuntimeError("queue_for_pickling should not be called "
365-
"inside the context of its own invocation.")
366-
367-
_QUEUE_FOR_PICKLING_TLS.queue = queue
368-
_QUEUE_FOR_PICKLING_TLS.alloc = alloc
369-
try:
370-
yield None
371-
finally:
372-
_QUEUE_FOR_PICKLING_TLS.queue = None
373-
_QUEUE_FOR_PICKLING_TLS.alloc = None
374-
375-
# }}}
376-
377-
378345
class Array:
379346
"""A :class:`numpy.ndarray` work-alike that stores its data and performs
380347
its computations on the compute device. :attr:`shape` and :attr:`dtype` work
@@ -742,36 +709,33 @@ def __init__(
742709

743710
def __getstate__(self):
744711
try:
745-
queue = _QUEUE_FOR_PICKLING_TLS.queue
712+
queue = cl._QUEUE_FOR_PICKLING_TLS.queue
746713
except AttributeError:
747714
queue = None
748715

749716
if queue is None:
750717
raise RuntimeError("CL Array instances can only be pickled while "
751-
"queue_for_pickling is active.")
718+
"cl.queue_for_pickling is active.")
752719

753720
state = self.__dict__.copy()
754721

755722
del state["allocator"]
756723
del state["context"]
757724
del state["events"]
758725
del state["queue"]
759-
del state["base_data"]
760-
state["data"] = self.get(queue=queue)
761-
762726
return state
763727

764728
def __setstate__(self, state):
765729
try:
766-
queue = _QUEUE_FOR_PICKLING_TLS.queue
767-
alloc = _QUEUE_FOR_PICKLING_TLS.alloc
730+
queue = cl._QUEUE_FOR_PICKLING_TLS.queue
731+
alloc = cl._QUEUE_FOR_PICKLING_TLS.alloc
768732
except AttributeError:
769733
queue = None
770734
alloc = None
771735

772736
if queue is None:
773-
raise RuntimeError("CL Array instances can only be pickled while "
774-
"queue_for_pickling is active.")
737+
raise RuntimeError("CL Array instances can only be unpickled while "
738+
"cl.queue_for_pickling is active.")
775739

776740
self.__dict__.update(state)
777741

@@ -780,20 +744,6 @@ def __setstate__(self, state):
780744
self.events = []
781745
self.queue = queue
782746

783-
if self.allocator is None:
784-
self.base_data = cl.Buffer(self.context, cl.mem_flags.READ_WRITE,
785-
self.nbytes)
786-
else:
787-
self.base_data = self.allocator(self.nbytes)
788-
789-
ary = state["data"]
790-
791-
# Mimics the stride update in _get() below
792-
if ary.strides != self.strides:
793-
ary = _as_strided(ary, strides=self.strides)
794-
795-
self.set(ary, queue=queue)
796-
797747
# }}}
798748

799749
@property

test/test_array.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,7 +2415,7 @@ def test_array_pickling(ctx_factory):
24152415
with pytest.raises(RuntimeError):
24162416
pickle.dumps(a_gpu)
24172417

2418-
with cl_array.queue_for_pickling(queue):
2418+
with cl.queue_for_pickling(queue):
24192419
a_gpu_pickled = pickle.loads(pickle.dumps(a_gpu))
24202420
assert np.all(a_gpu_pickled.get() == a)
24212421

@@ -2424,7 +2424,7 @@ def test_array_pickling(ctx_factory):
24242424
a_gpu_tagged = TaggableCLArray(queue, a.shape, a.dtype, tags={"foo", "bar"})
24252425
a_gpu_tagged.set(a)
24262426

2427-
with cl_array.queue_for_pickling(queue):
2427+
with cl.queue_for_pickling(queue):
24282428
a_gpu_tagged_pickled = pickle.loads(pickle.dumps(a_gpu_tagged))
24292429

24302430
assert np.all(a_gpu_tagged_pickled.get() == a)
@@ -2437,14 +2437,15 @@ def test_array_pickling(ctx_factory):
24372437
from pyopencl.characterize import has_coarse_grain_buffer_svm
24382438

24392439
if has_coarse_grain_buffer_svm(queue.device):
2440-
from pyopencl.tools import SVMAllocator, SVMPool
2440+
from pyopencl.tools import SVMAllocator
24412441

24422442
alloc = SVMAllocator(context, alignment=0, queue=queue)
2443-
alloc = SVMPool(alloc)
2443+
# FIXME: SVMPool is not picklable
2444+
# alloc = SVMPool(alloc)
24442445

24452446
a_dev = cl_array.to_device(queue, a, allocator=alloc)
24462447

2447-
with cl_array.queue_for_pickling(queue, alloc):
2448+
with cl.queue_for_pickling(queue, alloc):
24482449
a_dev_pickled = pickle.loads(pickle.dumps(a_dev))
24492450

24502451
assert np.all(a_dev_pickled.get() == a)

0 commit comments

Comments
 (0)