Skip to content

Commit 00876e1

Browse files
add _get_queue_for_pickling, outline some pool support
1 parent 001615c commit 00876e1

File tree

2 files changed

+90
-43
lines changed

2 files changed

+90
-43
lines changed

pyopencl/__init__.py

Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
MemoryObject,
138138
MemoryMap,
139139
Buffer,
140+
PooledBuffer,
140141

141142
_Program,
142143
Kernel,
@@ -197,7 +198,7 @@
197198
enqueue_migrate_mem_objects, unload_platform_compiler)
198199

199200
if get_cl_header_version() >= (2, 0):
200-
from pyopencl._cl import SVM, SVMAllocation, SVMPointer
201+
from pyopencl._cl import SVM, SVMAllocation, SVMPointer, PooledSVM
201202

202203
if _cl.have_gl():
203204
from pyopencl._cl import ( # noqa: F401
@@ -2439,21 +2440,28 @@ def queue_for_pickling(queue, alloc=None):
24392440
_QUEUE_FOR_PICKLING_TLS.alloc = None
24402441

24412442

2442-
def _getstate_buffer(self):
2443-
import pyopencl as cl
2444-
state = {}
2445-
state["size"] = self.size
2446-
state["flags"] = self.flags
2447-
2443+
def _get_queue_for_pickling(obj):
24482444
try:
24492445
queue = _QUEUE_FOR_PICKLING_TLS.queue
2446+
alloc = _QUEUE_FOR_PICKLING_TLS.alloc
24502447
except AttributeError:
24512448
queue = None
24522449

24532450
if queue is None:
2454-
raise RuntimeError("CL Buffer instances can only be pickled while "
2451+
raise RuntimeError(f"{type(obj).__name__} instances can only be pickled while "
24552452
"queue_for_pickling is active.")
24562453

2454+
return queue, alloc
2455+
2456+
2457+
def _getstate_buffer(self):
2458+
import pyopencl as cl
2459+
queue, _alloc = _get_queue_for_pickling(self)
2460+
2461+
state = {}
2462+
state["size"] = self.size
2463+
state["flags"] = self.flags
2464+
24572465
a = bytearray(self.size)
24582466
cl.enqueue_copy(queue, a, self)
24592467

@@ -2463,42 +2471,57 @@ def _getstate_buffer(self):
24632471

24642472

24652473
def _setstate_buffer(self, state):
2466-
try:
2467-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2468-
except AttributeError:
2469-
queue = None
2470-
2471-
if queue is None:
2472-
raise RuntimeError("CL Buffer instances can only be unpickled while "
2473-
"queue_for_pickling is active.")
2474+
import pyopencl as cl
2475+
queue, _alloc = _get_queue_for_pickling(self)
24742476

24752477
size = state["size"]
24762478
flags = state["flags"]
24772479

2478-
import pyopencl as cl
2479-
24802480
a = state["_pickle_data"]
24812481
Buffer.__init__(self, queue.context, flags | cl.mem_flags.COPY_HOST_PTR, size, a)
24822482

24832483

24842484
Buffer.__getstate__ = _getstate_buffer
24852485
Buffer.__setstate__ = _setstate_buffer
24862486

2487+
2488+
def _getstate_pooledbuffer(self):
2489+
import pyopencl as cl
2490+
queue, _alloc = _get_queue_for_pickling(self)
2491+
2492+
state = {}
2493+
state["size"] = self.size
2494+
state["flags"] = self.flags
2495+
2496+
a = bytearray(self.size)
2497+
cl.enqueue_copy(queue, a, self)
2498+
state["_pickle_data"] = a
2499+
2500+
return state
2501+
2502+
2503+
def _setstate_pooledbuffer(self, state):
2504+
_queue, _alloc = _get_queue_for_pickling(self)
2505+
2506+
_size = state["size"]
2507+
_flags = state["flags"]
2508+
2509+
_a = state["_pickle_data"]
2510+
# FIXME: Unclear what to do here - PooledBuffer does not have __init__
2511+
2512+
2513+
PooledBuffer.__getstate__ = _getstate_pooledbuffer
2514+
PooledBuffer.__setstate__ = _setstate_pooledbuffer
2515+
2516+
24872517
if get_cl_header_version() >= (2, 0):
2488-
def _getstate_svm(self):
2518+
def _getstate_svmallocation(self):
24892519
import pyopencl as cl
24902520

24912521
state = {}
24922522
state["size"] = self.size
24932523

2494-
try:
2495-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2496-
except AttributeError:
2497-
queue = None
2498-
2499-
if queue is None:
2500-
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2501-
"pickled while queue_for_pickling is active.")
2524+
queue, _alloc = _get_queue_for_pickling(self)
25022525

25032526
a = bytearray(self.size)
25042527
cl.enqueue_copy(queue, a, self)
@@ -2507,17 +2530,10 @@ def _getstate_svm(self):
25072530

25082531
return state
25092532

2510-
def _setstate_svm(self, state):
2533+
def _setstate_svmallocation(self, state):
25112534
import pyopencl as cl
25122535

2513-
try:
2514-
queue = _QUEUE_FOR_PICKLING_TLS.queue
2515-
except AttributeError:
2516-
queue = None
2517-
2518-
if queue is None:
2519-
raise RuntimeError(f"{self.__class__.__name__} instances can only be "
2520-
"unpickled while queue_for_pickling is active.")
2536+
queue, _alloc = _get_queue_for_pickling(self)
25212537

25222538
size = state["size"]
25232539

@@ -2526,8 +2542,33 @@ def _setstate_svm(self, state):
25262542
queue=queue)
25272543
cl.enqueue_copy(queue, self, a)
25282544

2529-
SVMAllocation.__getstate__ = _getstate_svm
2530-
SVMAllocation.__setstate__ = _setstate_svm
2545+
SVMAllocation.__getstate__ = _getstate_svmallocation
2546+
SVMAllocation.__setstate__ = _setstate_svmallocation
2547+
2548+
def _getstate_pooled_svm(self):
2549+
import pyopencl as cl
2550+
2551+
state = {}
2552+
state["size"] = self.size
2553+
2554+
queue, _alloc = _get_queue_for_pickling(self)
2555+
2556+
a = bytearray(self.size)
2557+
cl.enqueue_copy(queue, a, self)
2558+
2559+
state["_pickle_data"] = a
2560+
2561+
return state
2562+
2563+
def _setstate_pooled_svm(self, state):
2564+
_queue, _alloc = _get_queue_for_pickling(self)
2565+
_size = state["size"]
2566+
_data = state["_pickle_data"]
2567+
2568+
# FIXME: Unclear what to do here - PooledSVM does not have __init__
2569+
2570+
PooledSVM.__getstate__ = _getstate_pooled_svm
2571+
PooledSVM.__setstate__ = _setstate_pooled_svm
25312572

25322573
# }}}
25332574

test/test_array.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,12 +2404,18 @@ def __init__(self, cq, shape, dtype, tags):
24042404
self.tags = tags
24052405

24062406

2407-
def test_array_pickling(ctx_factory):
2407+
@pytest.mark.parametrize("use_mempool", [False, True])
2408+
def test_array_pickling(ctx_factory, use_mempool):
24082409
context = ctx_factory()
24092410
queue = cl.CommandQueue(context)
24102411

2412+
if use_mempool:
2413+
alloc = cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))
2414+
else:
2415+
alloc = None
2416+
24112417
a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
2412-
a_gpu = cl_array.to_device(queue, a)
2418+
a_gpu = cl_array.to_device(queue, a, allocator=alloc)
24132419

24142420
import pickle
24152421
with pytest.raises(RuntimeError):
@@ -2437,11 +2443,11 @@ def test_array_pickling(ctx_factory):
24372443
from pyopencl.characterize import has_coarse_grain_buffer_svm
24382444

24392445
if has_coarse_grain_buffer_svm(queue.device):
2440-
from pyopencl.tools import SVMAllocator
2446+
from pyopencl.tools import SVMAllocator, SVMPool
24412447

24422448
alloc = SVMAllocator(context, alignment=0, queue=queue)
2443-
# FIXME: SVMPool is not picklable
2444-
# alloc = SVMPool(alloc)
2449+
if use_mempool:
2450+
alloc = SVMPool(alloc)
24452451

24462452
a_dev = cl_array.to_device(queue, a, allocator=alloc)
24472453

0 commit comments

Comments
 (0)