Skip to content

Commit b8f6cb9

Browse files
authored
feature(threading): use explicit/configurable thread pool executor (#2327)
1 parent 3964eab commit b8f6cb9

File tree

10 files changed

+114
-41
lines changed

10 files changed

+114
-41
lines changed

src/zarr/codecs/_v2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING
56

@@ -8,7 +9,6 @@
89

910
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec
1011
from zarr.core.buffer import Buffer, NDBuffer, default_buffer_prototype
11-
from zarr.core.common import to_thread
1212
from zarr.registry import get_ndbuffer_class
1313

1414
if TYPE_CHECKING:
@@ -30,7 +30,7 @@ async def _decode_single(
3030
) -> NDBuffer:
3131
if self.compressor is not None:
3232
chunk_numpy_array = ensure_ndarray(
33-
await to_thread(self.compressor.decode, chunk_bytes.as_array_like())
33+
await asyncio.to_thread(self.compressor.decode, chunk_bytes.as_array_like())
3434
)
3535
else:
3636
chunk_numpy_array = ensure_ndarray(chunk_bytes.as_array_like())
@@ -54,7 +54,7 @@ async def _encode_single(
5454
):
5555
chunk_numpy_array = chunk_numpy_array.copy(order="A")
5656
encoded_chunk_bytes = ensure_bytes(
57-
await to_thread(self.compressor.encode, chunk_numpy_array)
57+
await asyncio.to_thread(self.compressor.encode, chunk_numpy_array)
5858
)
5959
else:
6060
encoded_chunk_bytes = ensure_bytes(chunk_numpy_array)
@@ -80,7 +80,7 @@ async def _decode_single(
8080
# apply filters in reverse order
8181
if self.filters is not None:
8282
for filter in self.filters[::-1]:
83-
chunk_ndarray = await to_thread(filter.decode, chunk_ndarray)
83+
chunk_ndarray = await asyncio.to_thread(filter.decode, chunk_ndarray)
8484

8585
# ensure correct chunk shape
8686
if chunk_ndarray.shape != chunk_spec.shape:
@@ -100,7 +100,7 @@ async def _encode_single(
100100

101101
if self.filters is not None:
102102
for filter in self.filters:
103-
chunk_ndarray = await to_thread(filter.encode, chunk_ndarray)
103+
chunk_ndarray = await asyncio.to_thread(filter.encode, chunk_ndarray)
104104

105105
return get_ndbuffer_class().from_ndarray_like(chunk_ndarray)
106106

src/zarr/codecs/blosc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from dataclasses import dataclass, replace
45
from enum import Enum
56
from functools import cached_property
@@ -10,7 +11,7 @@
1011

1112
from zarr.abc.codec import BytesBytesCodec
1213
from zarr.core.buffer.cpu import as_numpy_array_wrapper
13-
from zarr.core.common import JSON, parse_enum, parse_named_configuration, to_thread
14+
from zarr.core.common import JSON, parse_enum, parse_named_configuration
1415
from zarr.registry import register_codec
1516

1617
if TYPE_CHECKING:
@@ -169,7 +170,7 @@ async def _decode_single(
169170
chunk_bytes: Buffer,
170171
chunk_spec: ArraySpec,
171172
) -> Buffer:
172-
return await to_thread(
173+
return await asyncio.to_thread(
173174
as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype
174175
)
175176

@@ -180,7 +181,7 @@ async def _encode_single(
180181
) -> Buffer | None:
181182
# Since blosc only support host memory, we convert the input and output of the encoding
182183
# between numpy array and buffer
183-
return await to_thread(
184+
return await asyncio.to_thread(
184185
lambda chunk: chunk_spec.prototype.buffer.from_bytes(
185186
self._blosc_codec.encode(chunk.as_numpy_array())
186187
),

src/zarr/codecs/gzip.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING
56

67
from numcodecs.gzip import GZip
78

89
from zarr.abc.codec import BytesBytesCodec
910
from zarr.core.buffer.cpu import as_numpy_array_wrapper
10-
from zarr.core.common import JSON, parse_named_configuration, to_thread
11+
from zarr.core.common import JSON, parse_named_configuration
1112
from zarr.registry import register_codec
1213

1314
if TYPE_CHECKING:
@@ -51,7 +52,7 @@ async def _decode_single(
5152
chunk_bytes: Buffer,
5253
chunk_spec: ArraySpec,
5354
) -> Buffer:
54-
return await to_thread(
55+
return await asyncio.to_thread(
5556
as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype
5657
)
5758

@@ -60,7 +61,7 @@ async def _encode_single(
6061
chunk_bytes: Buffer,
6162
chunk_spec: ArraySpec,
6263
) -> Buffer | None:
63-
return await to_thread(
64+
return await asyncio.to_thread(
6465
as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype
6566
)
6667

src/zarr/codecs/zstd.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from dataclasses import dataclass
45
from functools import cached_property
56
from importlib.metadata import version
@@ -9,7 +10,7 @@
910

1011
from zarr.abc.codec import BytesBytesCodec
1112
from zarr.core.buffer.cpu import as_numpy_array_wrapper
12-
from zarr.core.common import JSON, parse_named_configuration, to_thread
13+
from zarr.core.common import JSON, parse_named_configuration
1314
from zarr.registry import register_codec
1415

1516
if TYPE_CHECKING:
@@ -73,7 +74,7 @@ async def _decode_single(
7374
chunk_bytes: Buffer,
7475
chunk_spec: ArraySpec,
7576
) -> Buffer:
76-
return await to_thread(
77+
return await asyncio.to_thread(
7778
as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype
7879
)
7980

@@ -82,7 +83,7 @@ async def _encode_single(
8283
chunk_bytes: Buffer,
8384
chunk_spec: ArraySpec,
8485
) -> Buffer | None:
85-
return await to_thread(
86+
return await asyncio.to_thread(
8687
as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype
8788
)
8889

src/zarr/core/common.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import contextvars
54
import functools
65
import operator
76
from collections.abc import Iterable, Mapping
@@ -10,7 +9,6 @@
109
TYPE_CHECKING,
1110
Any,
1211
Literal,
13-
ParamSpec,
1412
TypeVar,
1513
cast,
1614
overload,
@@ -62,17 +60,6 @@ async def run(item: tuple[Any]) -> V:
6260
return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
6361

6462

65-
P = ParamSpec("P")
66-
U = TypeVar("U")
67-
68-
69-
async def to_thread(func: Callable[P, U], /, *args: P.args, **kwargs: P.kwargs) -> U:
70-
loop = asyncio.get_running_loop()
71-
ctx = contextvars.copy_context()
72-
func_call = functools.partial(ctx.run, func, *args, **kwargs)
73-
return await loop.run_in_executor(None, func_call)
74-
75-
7663
E = TypeVar("E", bound=Enum)
7764

7865

src/zarr/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def reset(self) -> None:
4444
"default_zarr_version": 3,
4545
"array": {"order": "C"},
4646
"async": {"concurrency": 10, "timeout": None},
47+
"threading": {"max_workers": None},
4748
"json_indent": 2,
4849
"codec_pipeline": {
4950
"path": "zarr.codecs.pipeline.BatchedCodecPipeline",

src/zarr/core/sync.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import atexit
5+
import logging
46
import threading
5-
from concurrent.futures import wait
7+
from concurrent.futures import ThreadPoolExecutor, wait
68
from typing import TYPE_CHECKING, TypeVar
79

810
from typing_extensions import ParamSpec
@@ -13,6 +15,9 @@
1315
from collections.abc import AsyncIterator, Coroutine
1416
from typing import Any
1517

18+
logger = logging.getLogger(__name__)
19+
20+
1621
P = ParamSpec("P")
1722
T = TypeVar("T")
1823

@@ -23,7 +28,7 @@
2328
None
2429
] # global event loop for any non-async instance
2530
_lock: threading.Lock | None = None # global lock placeholder
26-
get_running_loop = asyncio.get_running_loop
31+
_executor: ThreadPoolExecutor | None = None # global executor placeholder
2732

2833

2934
class SyncError(Exception):
@@ -41,6 +46,51 @@ def _get_lock() -> threading.Lock:
4146
return _lock
4247

4348

49+
def _get_executor() -> ThreadPoolExecutor:
    """Return the global Zarr thread pool executor.

    The executor is allocated lazily on first use. Its size comes from the
    ``threading.max_workers`` config key (``None`` lets
    :class:`concurrent.futures.ThreadPoolExecutor` choose its default), and it
    is installed as the default executor of the Zarr event loop so that
    ``asyncio.to_thread`` / ``run_in_executor(None, ...)`` calls use it.

    Returns
    -------
    ThreadPoolExecutor
        The shared, lazily-created executor.
    """
    global _executor
    if _executor is None:
        max_workers = config.get("threading.max_workers", None)
        # NOTE: a leftover debug print and commented-out validation of
        # max_workers were removed here; ThreadPoolExecutor itself raises
        # ValueError for max_workers <= 0.
        _executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool")
        _get_loop().set_default_executor(_executor)
    return _executor
63+
64+
65+
def cleanup_resources() -> None:
66+
global _executor
67+
if _executor:
68+
_executor.shutdown(wait=True, cancel_futures=True)
69+
_executor = None
70+
71+
if loop[0] is not None:
72+
with _get_lock():
73+
# Stop the event loop safely
74+
loop[0].call_soon_threadsafe(loop[0].stop) # Stop loop from another thread
75+
if iothread[0] is not None:
76+
iothread[0].join(timeout=0.2) # Add a timeout to avoid hanging
77+
78+
if iothread[0].is_alive():
79+
logger.warning(
80+
"Thread did not finish cleanly; forcefully closing the event loop."
81+
)
82+
83+
# Forcefully close the event loop to release resources
84+
loop[0].close()
85+
86+
# dereference the loop and iothread
87+
loop[0] = None
88+
iothread[0] = None
89+
90+
91+
atexit.register(cleanup_resources)
92+
93+
4494
async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
4595
"""
4696
Await a coroutine and return the result of running it. If awaiting the coroutine raises an
@@ -105,10 +155,10 @@ def _get_loop() -> asyncio.AbstractEventLoop:
105155
if loop[0] is None:
106156
new_loop = asyncio.new_event_loop()
107157
loop[0] = new_loop
108-
th = threading.Thread(target=new_loop.run_forever, name="zarrIO")
109-
th.daemon = True
110-
th.start()
111-
iothread[0] = th
158+
iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io")
159+
assert iothread[0] is not None
160+
iothread[0].daemon = True
161+
iothread[0].start()
112162
assert loop[0] is not None
113163
return loop[0]
114164

src/zarr/storage/local.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import io
45
import os
56
import shutil
@@ -8,7 +9,7 @@
89

910
from zarr.abc.store import ByteRangeRequest, Store
1011
from zarr.core.buffer import Buffer
11-
from zarr.core.common import concurrent_map, to_thread
12+
from zarr.core.common import concurrent_map
1213

1314
if TYPE_CHECKING:
1415
from collections.abc import AsyncGenerator, Iterable
@@ -134,7 +135,7 @@ async def get(
134135
path = self.root / key
135136

136137
try:
137-
return await to_thread(_get, path, prototype, byte_range)
138+
return await asyncio.to_thread(_get, path, prototype, byte_range)
138139
except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
139140
return None
140141

@@ -159,7 +160,7 @@ async def get_partial_values(
159160
assert isinstance(key, str)
160161
path = self.root / key
161162
args.append((_get, path, prototype, byte_range))
162-
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
163+
return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit
163164

164165
async def set(self, key: str, value: Buffer) -> None:
165166
return await self._set(key, value)
@@ -178,7 +179,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
178179
if not isinstance(value, Buffer):
179180
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
180181
path = self.root / key
181-
await to_thread(_put, path, value, start=None, exclusive=exclusive)
182+
await asyncio.to_thread(_put, path, value, start=None, exclusive=exclusive)
182183

183184
async def set_partial_values(
184185
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]
@@ -189,19 +190,19 @@ async def set_partial_values(
189190
assert isinstance(key, str)
190191
path = self.root / key
191192
args.append((_put, path, value, start))
192-
await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
193+
await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit
193194

194195
async def delete(self, key: str) -> None:
195196
self._check_writable()
196197
path = self.root / key
197198
if path.is_dir(): # TODO: support deleting directories? shutil.rmtree?
198199
shutil.rmtree(path)
199200
else:
200-
await to_thread(path.unlink, True) # Q: we may want to raise if path is missing
201+
await asyncio.to_thread(path.unlink, True) # Q: we may want to raise if path is missing
201202

202203
async def exists(self, key: str) -> bool:
203204
path = self.root / key
204-
return await to_thread(path.is_file)
205+
return await asyncio.to_thread(path.is_file)
205206

206207
async def list(self) -> AsyncGenerator[str, None]:
207208
"""Retrieve all keys in the store.

tests/v3/test_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def test_config_defaults_set() -> None:
4242
"default_zarr_version": 3,
4343
"array": {"order": "C"},
4444
"async": {"concurrency": 10, "timeout": None},
45+
"threading": {"max_workers": None},
4546
"json_indent": 2,
4647
"codec_pipeline": {
4748
"path": "zarr.codecs.pipeline.BatchedCodecPipeline",

0 commit comments

Comments
 (0)