feature(threading): use explicit/configurable thread pool executor #2327

Merged: 1 commit, Oct 10, 2024

10 changes: 5 additions & 5 deletions src/zarr/codecs/_v2.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
@@ -8,7 +9,6 @@
 from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec
 from zarr.core.buffer import Buffer, NDBuffer, default_buffer_prototype
-from zarr.core.common import to_thread
 from zarr.registry import get_ndbuffer_class
 
 if TYPE_CHECKING:
@@ -30,7 +30,7 @@ async def _decode_single(
     ) -> NDBuffer:
         if self.compressor is not None:
             chunk_numpy_array = ensure_ndarray(
-                await to_thread(self.compressor.decode, chunk_bytes.as_array_like())
+                await asyncio.to_thread(self.compressor.decode, chunk_bytes.as_array_like())
             )
         else:
             chunk_numpy_array = ensure_ndarray(chunk_bytes.as_array_like())
@@ -54,7 +54,7 @@ async def _encode_single(
             ):
                 chunk_numpy_array = chunk_numpy_array.copy(order="A")
             encoded_chunk_bytes = ensure_bytes(
-                await to_thread(self.compressor.encode, chunk_numpy_array)
+                await asyncio.to_thread(self.compressor.encode, chunk_numpy_array)
             )
         else:
             encoded_chunk_bytes = ensure_bytes(chunk_numpy_array)
@@ -80,7 +80,7 @@ async def _decode_single(
         # apply filters in reverse order
         if self.filters is not None:
             for filter in self.filters[::-1]:
-                chunk_ndarray = await to_thread(filter.decode, chunk_ndarray)
+                chunk_ndarray = await asyncio.to_thread(filter.decode, chunk_ndarray)
 
         # ensure correct chunk shape
         if chunk_ndarray.shape != chunk_spec.shape:
@@ -100,7 +100,7 @@ async def _encode_single(
         if self.filters is not None:
             for filter in self.filters:
-                chunk_ndarray = await to_thread(filter.encode, chunk_ndarray)
+                chunk_ndarray = await asyncio.to_thread(filter.encode, chunk_ndarray)
 
         return get_ndbuffer_class().from_ndarray_like(chunk_ndarray)

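The codec changes above are mechanical: every blocking numcodecs call that previously went through the removed `zarr.core.common.to_thread` helper now goes through the stdlib `asyncio.to_thread`. A minimal sketch of the shared pattern, using numcodecs' `Zstd` purely as an illustrative stand-in for the compressors and filters touched here:

```python
import asyncio

import numpy as np
from numcodecs import Zstd


async def roundtrip(arr: np.ndarray) -> np.ndarray:
    codec = Zstd(level=3)
    # Offload the blocking compression/decompression to a worker thread
    # so the event loop stays free to schedule other chunk operations.
    encoded = await asyncio.to_thread(codec.encode, arr)
    decoded = await asyncio.to_thread(codec.decode, encoded)
    return np.frombuffer(decoded, dtype=arr.dtype).reshape(arr.shape)


print(asyncio.run(roundtrip(np.arange(12, dtype="i4").reshape(3, 4))))
```
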
7 changes: 4 additions & 3 deletions src/zarr/codecs/blosc.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from dataclasses import dataclass, replace
 from enum import Enum
 from functools import cached_property
@@ -10,7 +11,7 @@
 from zarr.abc.codec import BytesBytesCodec
 from zarr.core.buffer.cpu import as_numpy_array_wrapper
-from zarr.core.common import JSON, parse_enum, parse_named_configuration, to_thread
+from zarr.core.common import JSON, parse_enum, parse_named_configuration
 from zarr.registry import register_codec
 
 if TYPE_CHECKING:
@@ -169,7 +170,7 @@ async def _decode_single(
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(
+        return await asyncio.to_thread(
             as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype
         )
 
@@ -180,7 +181,7 @@ async def _encode_single(
     ) -> Buffer | None:
         # Since blosc only support host memory, we convert the input and output of the encoding
        # between numpy array and buffer
-        return await to_thread(
+        return await asyncio.to_thread(
             lambda chunk: chunk_spec.prototype.buffer.from_bytes(
                 self._blosc_codec.encode(chunk.as_numpy_array())
             ),

7 changes: 4 additions & 3 deletions src/zarr/codecs/gzip.py

@@ -1,13 +1,14 @@
 from __future__ import annotations
 
+import asyncio
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from numcodecs.gzip import GZip
 
 from zarr.abc.codec import BytesBytesCodec
 from zarr.core.buffer.cpu import as_numpy_array_wrapper
-from zarr.core.common import JSON, parse_named_configuration, to_thread
+from zarr.core.common import JSON, parse_named_configuration
 from zarr.registry import register_codec
 
 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ async def _decode_single(
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(
+        return await asyncio.to_thread(
             as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype
         )
 
@@ -60,7 +61,7 @@ async def _encode_single(
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer | None:
-        return await to_thread(
+        return await asyncio.to_thread(
             as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype
         )

7 changes: 4 additions & 3 deletions src/zarr/codecs/zstd.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from dataclasses import dataclass
 from functools import cached_property
 from importlib.metadata import version
@@ -9,7 +10,7 @@
 from zarr.abc.codec import BytesBytesCodec
 from zarr.core.buffer.cpu import as_numpy_array_wrapper
-from zarr.core.common import JSON, parse_named_configuration, to_thread
+from zarr.core.common import JSON, parse_named_configuration
 from zarr.registry import register_codec
 
 if TYPE_CHECKING:
@@ -73,7 +74,7 @@ async def _decode_single(
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(
+        return await asyncio.to_thread(
             as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype
         )
 
@@ -82,7 +83,7 @@ async def _encode_single(
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer | None:
-        return await to_thread(
+        return await asyncio.to_thread(
             as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype
         )

13 changes: 0 additions & 13 deletions src/zarr/core/common.py

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import asyncio
-import contextvars
 import functools
 import operator
 from collections.abc import Iterable, Mapping
@@ -10,7 +9,6 @@
     TYPE_CHECKING,
     Any,
     Literal,
-    ParamSpec,
     TypeVar,
     cast,
     overload,
@@ -60,17 +58,6 @@ async def run(item: tuple[Any]) -> V:
     return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
 
 
-P = ParamSpec("P")
-U = TypeVar("U")
-
-
-async def to_thread(func: Callable[P, U], /, *args: P.args, **kwargs: P.kwargs) -> U:
-    loop = asyncio.get_running_loop()
-    ctx = contextvars.copy_context()
-    func_call = functools.partial(ctx.run, func, *args, **kwargs)
-    return await loop.run_in_executor(None, func_call)
-
-
 E = TypeVar("E", bound=Enum)

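Dropping this helper is safe because, since Python 3.9, the stdlib ships the same logic. CPython's `asyncio.to_thread` is essentially the code below: it copies the current `contextvars` context and submits the call to the running loop's default executor, exactly like the deleted function:

```python
# Essentially the stdlib implementation of asyncio.to_thread (Python 3.9+),
# reproduced to show why the removed zarr.core.common.to_thread was redundant.
import asyncio
import contextvars
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
U = TypeVar("U")


async def to_thread(func: Callable[P, U], /, *args: P.args, **kwargs: P.kwargs) -> U:
    # Copy the caller's context so contextvars survive the thread hop,
    # then run on the event loop's *default* executor (the None argument).
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)
```
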
1 change: 1 addition & 0 deletions src/zarr/core/config.py

@@ -44,6 +44,7 @@ def reset(self) -> None:
             "default_zarr_version": 3,
             "array": {"order": "C"},
             "async": {"concurrency": 10, "timeout": None},
+            "threading": {"max_workers": None},
             "json_indent": 2,
             "codec_pipeline": {
                 "path": "zarr.codecs.pipeline.BatchedCodecPipeline",

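With the new `threading.max_workers` key, the pool size becomes user-configurable. A short sketch of how a user would set it through zarr's donfig-backed config object (assuming the top-level `zarr.config` export; `None`, the default, leaves worker-count sizing to `ThreadPoolExecutor`):

```python
import zarr

# Cap zarr's shared worker pool at 4 threads. The executor is created
# lazily on first use, so configure this before any chunk I/O starts.
zarr.config.set({"threading.max_workers": 4})
```
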
62 changes: 56 additions & 6 deletions src/zarr/core/sync.py

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import asyncio
+import atexit
+import logging
 import threading
-from concurrent.futures import wait
+from concurrent.futures import ThreadPoolExecutor, wait
 from typing import TYPE_CHECKING, TypeVar
 
 from typing_extensions import ParamSpec
@@ -13,6 +15,9 @@
     from collections.abc import AsyncIterator, Coroutine
     from typing import Any
 
+logger = logging.getLogger(__name__)
+
+
 P = ParamSpec("P")
 T = TypeVar("T")
 
@@ -23,7 +28,7 @@
     None
 ]  # global event loop for any non-async instance
 _lock: threading.Lock | None = None  # global lock placeholder
-get_running_loop = asyncio.get_running_loop
+_executor: ThreadPoolExecutor | None = None  # global executor placeholder
 
 
 class SyncError(Exception):
@@ -41,6 +46,51 @@ def _get_lock() -> threading.Lock:
     return _lock
 
 
+def _get_executor() -> ThreadPoolExecutor:
+    """Return the global Zarr thread pool executor.
+
+    The executor is allocated on first use.
+    """
+    global _executor
+    if not _executor:
+        max_workers = config.get("threading.max_workers", None)
+        _executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool")
+        _get_loop().set_default_executor(_executor)
+    return _executor
+
+
+def cleanup_resources() -> None:
+    global _executor
+    if _executor:
+        _executor.shutdown(wait=True, cancel_futures=True)
+    _executor = None
+
+    if loop[0] is not None:
+        with _get_lock():
+            # Stop the event loop safely
+            loop[0].call_soon_threadsafe(loop[0].stop)  # Stop loop from another thread
+            if iothread[0] is not None:
+                iothread[0].join(timeout=0.2)  # Add a timeout to avoid hanging
+
+                if iothread[0].is_alive():
+                    logger.warning(
+                        "Thread did not finish cleanly; forcefully closing the event loop."
+                    )
+
+            # Forcefully close the event loop to release resources
+            loop[0].close()
+
+            # dereference the loop and iothread
+            loop[0] = None
+            iothread[0] = None
+
+
+atexit.register(cleanup_resources)
+
+
 async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
     """
     Await a coroutine and return the result of running it. If awaiting the coroutine raises an
@@ -105,10 +155,10 @@ def _get_loop() -> asyncio.AbstractEventLoop:
     if loop[0] is None:
         new_loop = asyncio.new_event_loop()
         loop[0] = new_loop
-        th = threading.Thread(target=new_loop.run_forever, name="zarrIO")
-        th.daemon = True
-        th.start()
-        iothread[0] = th
+        iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io")
+        assert iothread[0] is not None
+        iothread[0].daemon = True
+        iothread[0].start()
     assert loop[0] is not None
     return loop[0]

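The piece that makes this work is `set_default_executor`: `asyncio.to_thread` delegates to `loop.run_in_executor(None, ...)`, and `None` selects the loop's default executor, so installing the pool in `_get_executor` reroutes every `to_thread` call through it. A self-contained demonstration of that mechanism, independent of zarr:

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor


async def main() -> None:
    # Install a named pool as the loop's default executor; every
    # asyncio.to_thread / run_in_executor(None, ...) call now lands here.
    pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="zarr_pool")
    asyncio.get_running_loop().set_default_executor(pool)
    name = await asyncio.to_thread(lambda: threading.current_thread().name)
    print(name)  # e.g. "zarr_pool_0"


asyncio.run(main())
```
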
15 changes: 8 additions & 7 deletions src/zarr/storage/local.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import io
 import os
 import shutil
@@ -8,7 +9,7 @@
 from zarr.abc.store import ByteRangeRequest, Store
 from zarr.core.buffer import Buffer
-from zarr.core.common import concurrent_map, to_thread
+from zarr.core.common import concurrent_map
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Iterable
@@ -134,7 +135,7 @@ async def get(
         path = self.root / key
 
         try:
-            return await to_thread(_get, path, prototype, byte_range)
+            return await asyncio.to_thread(_get, path, prototype, byte_range)
         except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
             return None
@@ -159,7 +160,7 @@ async def get_partial_values(
             assert isinstance(key, str)
             path = self.root / key
             args.append((_get, path, prototype, byte_range))
-        return await concurrent_map(args, to_thread, limit=None)  # TODO: fix limit
+        return await concurrent_map(args, asyncio.to_thread, limit=None)  # TODO: fix limit
 
     async def set(self, key: str, value: Buffer) -> None:
         return await self._set(key, value)
@@ -178,7 +179,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
         if not isinstance(value, Buffer):
             raise TypeError("LocalStore.set(): `value` must a Buffer instance")
         path = self.root / key
-        await to_thread(_put, path, value, start=None, exclusive=exclusive)
+        await asyncio.to_thread(_put, path, value, start=None, exclusive=exclusive)
@@ -189,19 +190,19 @@ async def set_partial_values(
             assert isinstance(key, str)
             path = self.root / key
             args.append((_put, path, value, start))
-        await concurrent_map(args, to_thread, limit=None)  # TODO: fix limit
+        await concurrent_map(args, asyncio.to_thread, limit=None)  # TODO: fix limit
 
     async def delete(self, key: str) -> None:
         self._check_writable()
         path = self.root / key
         if path.is_dir():  # TODO: support deleting directories? shutil.rmtree?
             shutil.rmtree(path)
         else:
-            await to_thread(path.unlink, True)  # Q: we may want to raise if path is missing
+            await asyncio.to_thread(path.unlink, True)  # Q: we may want to raise if path is missing
 
     async def exists(self, key: str) -> bool:
         path = self.root / key
-        return await to_thread(path.is_file)
+        return await asyncio.to_thread(path.is_file)
 
     async def list(self) -> AsyncGenerator[str, None]:
         """Retrieve all keys in the store.

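In `get_partial_values` and `set_partial_values`, the per-key work is fanned out by passing `asyncio.to_thread` as the mapper to `concurrent_map`. A hypothetical, simplified reduction of that pattern; `concurrent_map` below is a stand-in for `zarr.core.common.concurrent_map` with `limit=None`, and `read_many` is an invented example function:

```python
import asyncio
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any


async def concurrent_map(
    items: list[tuple[Any, ...]], func: Callable[..., Awaitable[Any]]
) -> list[Any]:
    # Unbounded fan-out: one task per args-tuple, gathered concurrently
    # (zarr's real concurrent_map also supports a semaphore-based limit).
    return await asyncio.gather(*(func(*item) for item in items))


async def read_many(paths: list[Path]) -> list[bytes]:
    # Mirrors LocalStore.get_partial_values: each blocking file read is
    # dispatched to a worker thread via asyncio.to_thread.
    args = [(p.read_bytes,) for p in paths]
    return await concurrent_map(args, asyncio.to_thread)
```
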
1 change: 1 addition & 0 deletions tests/v3/test_config.py

@@ -42,6 +42,7 @@ def test_config_defaults_set() -> None:
         "default_zarr_version": 3,
         "array": {"order": "C"},
         "async": {"concurrency": 10, "timeout": None},
+        "threading": {"max_workers": None},
         "json_indent": 2,
         "codec_pipeline": {
             "path": "zarr.codecs.pipeline.BatchedCodecPipeline",