Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9f874fc
Draft coffea.compute sketch
nsmith- Nov 5, 2025
7333d5d
Add more notes
nsmith- Nov 5, 2025
dccdce8
All itertools
nsmith- Nov 5, 2025
21c5b5c
Catch exceptions
nsmith- Nov 5, 2025
554621d
Split protocol out
nsmith- Nov 5, 2025
02ec10a
Re-organize and make tests
nsmith- Nov 6, 2025
dce40a9
profile
nsmith- Nov 6, 2025
c2ad4e2
Add some data tests
nsmith- Nov 6, 2025
dcc7bd9
Add mypy type checking on submodule
nsmith- Nov 6, 2025
267379f
Genericize the protocols
nsmith- Nov 7, 2025
88217db
Genericize data
nsmith- Nov 7, 2025
104eeb4
Mock up a data preparation workflow
nsmith- Nov 7, 2025
1f13488
A way to include typed metadata in function context
nsmith- Nov 9, 2025
733bab7
Consolidate a bit the classes
nsmith- Nov 9, 2025
3223c2f
Refactor context, add groups
nsmith- Nov 10, 2025
9e7a7d4
lint
nsmith- Nov 10, 2025
d70db53
py3.10 compat
nsmith- Nov 10, 2025
0378692
Move to py3.10 CI
nsmith- Nov 11, 2025
0f28ab8
Merge branch 'master' into compute
nsmith- Nov 11, 2025
91ee7a1
oops
nsmith- Nov 11, 2025
f129cbe
Merge branch 'master' into compute
nsmith- Nov 17, 2025
6c0ea5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
f50f736
Correctly import Self
nsmith- Nov 17, 2025
ae71f9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
c67e5bf
more robust test
nsmith- Nov 17, 2025
4c85e1f
wow these github runners are flaky
nsmith- Nov 18, 2025
18314a1
Re-organize data into separate modules
nsmith- Nov 18, 2025
269efbe
Add small notes on result()
nsmith- Nov 19, 2025
77291a1
Merge branch 'master' into compute
nsmith- Nov 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
- id: setup-cfg-fmt

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3
rev: v0.14.4
hooks:
- id: ruff
args: [--fix, --show-fixes]
Expand All @@ -59,3 +59,10 @@ repos:
hooks:
- id: codespell
args: ["--skip=*.ipynb,*.svg","-L HEP,hist,Hist,nd,SubJet,subjet,Subjet,PTD,ptd,fPt,fpt,Ser,ser,hda"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
pass_filenames: false # to allow mypy to respect pyproject.toml config
additional_dependencies: []
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ dependencies = [
"fsspec",
"rich",
"ipywidgets",
"more-itertools",
"typing-extensions;python_version<'3.11'",
]
dynamic = ["version"]

Expand Down Expand Up @@ -137,3 +139,11 @@ line-length = 160

[tool.ruff.lint]
ignore = ["F403", "F405", "E402"]

[tool.mypy]
python_version = "3.10"
files = "src/coffea/compute"

[[tool.mypy.overrides]]
module = ["uproot.*"]
follow_untyped_imports = true
6 changes: 2 additions & 4 deletions src/coffea/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from . import _version
from ._version import __version__

__version__ = _version.__version__

__all__ = ["deprecations_as_errors"]
__all__ = ["__version__"]
8 changes: 8 additions & 0 deletions src/coffea/compute/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Top-level namespace for ``coffea.compute``.

Re-exports the ``data``, ``errors``, ``func``, and ``backends`` submodules.
"""
from coffea.compute import backends, data, errors, func

__all__ = [
    "data",
    "errors",
    "func",
    "backends",
]
6 changes: 6 additions & 0 deletions src/coffea/compute/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Compute backends

The compute backends are implemented in their own submodules.
They will not be imported here in case they have additional dependencies
that are not always available.
"""
300 changes: 300 additions & 0 deletions src/coffea/compute/backends/threaded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
from collections.abc import Iterator
from dataclasses import dataclass, field
from itertools import starmap
from queue import Queue
from threading import Condition, Thread
from types import TracebackType
from typing import TYPE_CHECKING, Generic

from coffea.compute.errors import (
ErrorAction,
ErrorPolicy,
FailedTaskElement,
TaskElement,
)
from coffea.compute.protocol import (
Backend,
Computable,
EmptyResult,
InputT,
ResultT,
Task,
TaskStatus,
WorkElement,
)


@dataclass
class Continuation(Generic[InputT, ResultT]):
    """The leftover work from a partially-completed computation.

    Iterating a continuation re-yields exactly those elements of the original
    computable that still need processing: the ones that failed, plus every
    element at or beyond the point where processing stopped.
    """

    original: Computable[InputT, ResultT]
    "The original computable item."
    status: TaskStatus
    "The status of the original computation task."
    failed_indices: frozenset[int]
    "Indices of task elements that failed in the original computation."
    continue_at: int
    "Index to continue processing from, in the case where the original task was cancelled."

    def __iter__(self) -> Iterator[WorkElement[InputT, ResultT]]:
        """Yield only the elements that still require (re)computation."""
        for position, element in enumerate(self.original):
            not_yet_processed = position >= self.continue_at
            if not_yet_processed or position in self.failed_indices:
                yield element


@dataclass
class _TaskState(Generic[InputT, ResultT]):
    """A snapshot of a task's progress: accumulated output, cursor position,
    recorded failures, and the current lifecycle status."""

    output: ResultT | EmptyResult = EmptyResult()
    next_index: int = 0
    failures: list[FailedTaskElement[InputT, ResultT]] = field(default_factory=list)
    status: TaskStatus = TaskStatus.PENDING

    def get_continuation(
        self, original: Computable[InputT, ResultT]
    ) -> Continuation[InputT, ResultT]:
        """Package this snapshot as a Continuation over *original*."""
        failed = frozenset(entry.index for entry in self.failures)
        return Continuation(
            original=original,
            status=self.status,
            failed_indices=failed,
            continue_at=self.next_index,
        )


def _try_advance(
    state: _TaskState[InputT, ResultT],
    element: TaskElement[InputT, ResultT],
    error_policy: ErrorPolicy,
) -> _TaskState[InputT, ResultT]:
    """Execute one task element and fold the outcome into a new state.

    On success, the element's result is merged into the accumulated output.
    On failure, *error_policy* decides what happens: the very first failure of
    the element is routed to ``error_policy.first_action`` and any subsequent
    failure of its replacement to ``error_policy.retry_action``; the policy
    may CANCEL the task, CONTINUE past the failed element, or RETRY it.

    Returns a new ``_TaskState``; *state* is never mutated.

    Note: the original implementation duplicated the CANCEL/CONTINUE/success
    handling between the first attempt and the retry loop; this version folds
    both into a single loop with identical behavior.
    """
    current = element
    first_attempt = True
    while True:
        try:
            result = current()
        except Exception as ex:
            # Consult the policy: first_action on the initial failure,
            # retry_action on every failure after that.
            if first_attempt:
                current, action = error_policy.first_action(current, ex)
                first_attempt = False
            else:
                current, action = error_policy.retry_action(current, ex)
            if action == ErrorAction.CANCEL:
                return _TaskState(
                    output=state.output,
                    next_index=state.next_index + 1,
                    failures=state.failures + [current],
                    status=TaskStatus.CANCELLED,
                )
            if action == ErrorAction.CONTINUE:
                return _TaskState(
                    output=state.output,
                    next_index=state.next_index + 1,
                    failures=state.failures + [current],
                    status=TaskStatus.RUNNING,
                )
            # Only remaining possibility: loop around and retry the element.
            assert action == ErrorAction.RETRY
        else:
            # This could use a more sophisticated merging strategy
            return _TaskState(
                output=state.output + result,
                next_index=state.next_index + 1,
                failures=state.failures,
                status=TaskStatus.RUNNING,
            )


class ThreadedTask(Generic[InputT, ResultT]):
    # Handle for a computation submitted to a threaded backend.  A worker
    # thread drives ``_run`` while the owning (user) thread observes progress
    # via ``result``/``wait``/``status``; the two sides coordinate through
    # the condition variable ``_cv``.
    item: Computable[InputT, ResultT]
    error_policy: ErrorPolicy
    _iter: Iterator[TaskElement[InputT, ResultT]]
    _state: _TaskState[InputT, ResultT]
    "To be modified only under _cv lock"
    _cv: Condition

    def __init__(
        self, item: Computable[InputT, ResultT], error_policy: ErrorPolicy
    ) -> None:
        self.item = item
        self.error_policy = error_policy
        # Lazily pair each input element with its position so failures can
        # later be reported by index.
        self._iter = starmap(TaskElement, enumerate(item))
        self._state = _TaskState()
        self._cv = Condition()

    def result(self) -> ResultT:
        """Block until the task finishes and return the merged output.

        Raises
        ------
        RuntimeError
            If any task element ultimately failed; the first recorded failure
            is attached as the chained exception.
        """
        # TODO: if backend is shutdown without waiting on all tasks, raise an error here
        self.wait()
        if self._state.failures:
            # Reraise the first error encountered
            msg = (
                f"Computation failed with {len(self._state.failures)} errors;\n"
                " use Task.partial_result() to access partial results and a\n"
                " continuation Computable for the remaining work.\n"
                " You can also adjust the ErrorPolicy to continue on certain errors.\n"
                " The first error is shown in the chained exception above."
            )
            raise RuntimeError(msg) from self._state.failures[0].exception
        out = self._state.output
        assert not isinstance(out, EmptyResult)
        return out

    def partial_result(
        self,
    ) -> tuple[ResultT | EmptyResult, Continuation[InputT, ResultT]]:
        """Return the output accumulated so far, plus a Continuation covering
        the failed and not-yet-processed elements."""
        # Hold lock so we get a consistent snapshot of state
        with self._cv:
            return self._state.output, self._state.get_continuation(self.item)

    def wait(self) -> None:
        """Block until the task reaches a terminal status."""
        with self._cv:
            self._cv.wait_for(self.done)

    def status(self) -> TaskStatus:
        # NOTE(review): read without holding _cv — presumably relies on
        # attribute rebinding of _state being atomic; confirm this is intended.
        return self._state.status

    def done(self) -> bool:
        # Also used as the predicate for Condition.wait_for in wait().
        return self._state.status.done()

    def cancel(self) -> None:
        """Request cancellation; the worker observes it between elements."""
        with self._cv:
            self._state.status = TaskStatus.CANCELLED
            self._cv.notify_all()

    def _run(self) -> None:
        """Run the task to completion.
        This is intended to be called by a single worker thread.
        """
        for task_element in self._iter:
            # Compute outside the lock; only the state swap happens under it.
            next_state = _try_advance(self._state, task_element, self.error_policy)
            with self._cv:
                # First check if we were aborted while working
                if self._state.status == TaskStatus.CANCELLED:
                    self._cv.notify_all()
                    return
                self._state = next_state
                if self._state.status.done():
                    self._cv.notify_all()
                    return
        with self._cv:
            # NOTE(review): if the computable yielded no elements at all, the
            # status is still PENDING here and this assert would fail —
            # confirm whether empty computables are expected.
            assert self._state.status == TaskStatus.RUNNING
            if self._state.failures:
                self._state.status = TaskStatus.INCOMPLETE
            else:
                self._state.status = TaskStatus.COMPLETE
            self._cv.notify_all()


class _Shutdown:
"""A sentinel to signal worker threads to exit.

In python 3.13+, we can use queue.ShutDown directly.
"""

pass


def _work(task_queue: Queue[ThreadedTask]) -> None:  # type: ignore[type-arg]
    """Worker loop: pull items off the queue and run them until shutdown."""
    while True:
        item = task_queue.get()
        if isinstance(item, _Shutdown):
            task_queue.task_done()
            return
        try:
            item._run()
        except Exception:
            # Any exceptions not caught by the task itself are bugs in the backend
            # TODO: find a way to report these in the user thread
            item.cancel()
        task_queue.task_done()


class SingleThreadedBackend:
    """A backend that executes submitted tasks sequentially on one worker thread.

    Use as a context manager: entering starts the worker thread and returns a
    running handle; exiting signals shutdown and joins the thread.
    """

    _task_queue: Queue[ThreadedTask | _Shutdown] | None  # type: ignore[type-arg]
    _thread: Thread | None

    def __init__(self) -> None:
        self._thread = None
        self._task_queue = None

    def __enter__(self) -> "RunningSingleThreadedBackend":
        self._task_queue = Queue()
        worker = Thread(
            target=_work,
            name="SingleThreadedBackend",
            args=(self._task_queue,),
        )
        self._thread = worker
        worker.start()
        return RunningSingleThreadedBackend(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        assert self._thread and self._task_queue
        # Signal the worker to drain remaining work and exit, then reap it.
        self._task_queue.put(_Shutdown())
        self._thread.join()
        self._thread = None
        self._task_queue = None


class RunningSingleThreadedBackend:
    """Handle for submitting work to an active ``SingleThreadedBackend``."""

    _backend: SingleThreadedBackend

    def __init__(self, backend: SingleThreadedBackend):  # type: ignore[type-arg]
        self._backend = backend

    def compute(
        self,
        item: Computable[InputT, ResultT],
        /,
        error_policy: ErrorPolicy = ErrorPolicy(),
    ) -> ThreadedTask[InputT, ResultT]:
        """Submit *item* for computation and return a task handle immediately."""
        queue = self._backend._task_queue
        if queue is None:
            raise RuntimeError(
                "Cannot compute on a backend that has been exited from its context manager"
            )
        if hasattr(item, "__next__"):
            raise TypeError("Computable items must be iterables, not iterators")
        task = ThreadedTask(item, error_policy)
        queue.put(task)
        return task

    def wait_all(self, progress: bool = False) -> None:
        """Wait for all tasks in the backend to complete.

        Parameters
        ----------
        progress : bool, optional
            If True, display a progress bar while waiting, by default False.
        """
        if progress:
            raise NotImplementedError("Progress bars are not yet implemented")
        queue = self._backend._task_queue
        if queue:
            queue.join()


if TYPE_CHECKING:
    # TODO: is this the best way to do this?
    # Static-only protocol conformance checks: assigning the concrete classes
    # to names annotated with the protocol types makes mypy verify structural
    # compatibility, at zero runtime cost.
    check1: type[Task] = ThreadedTask
    check2: type[Backend] = SingleThreadedBackend
Loading
Loading