Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7ee50fd
Add abstract base class and a no op implementation to enable task can…
VegetarianOrc Oct 29, 2025
dbaca99
fix some linter errors
VegetarianOrc Oct 29, 2025
847d9cb
Some PR feedback. Up min python version to 3.10
VegetarianOrc Oct 29, 2025
13fde5b
Update some docs to more clearly highlight expected behavior of opera…
VegetarianOrc Nov 3, 2025
b18edcc
Simple logging interceptor working with an InterceptedOperationHandle…
VegetarianOrc Nov 7, 2025
0f76a23
Update test to confirm interceptors are applied in the order provided…
VegetarianOrc Nov 7, 2025
40b4eb3
Do some renaming. Add some doc strings. remove type aliases that woun…
VegetarianOrc Nov 7, 2025
139173b
Remove request_deadline as that's part of a different PR
VegetarianOrc Nov 7, 2025
6561cd3
remove some unused imports
VegetarianOrc Nov 7, 2025
0c8100e
Use public export in tests
VegetarianOrc Nov 7, 2025
0d37c34
Fix some linter errors
VegetarianOrc Nov 7, 2025
60ef746
use cancellation in tests after rebasing to support new python
VegetarianOrc Nov 7, 2025
be5d42a
fix docstring errors
VegetarianOrc Nov 7, 2025
e66c5fc
merge main
VegetarianOrc Nov 7, 2025
2935dbc
rename interceptor to middleware. Expose operation context to middleware
VegetarianOrc Nov 11, 2025
74ca843
fix formatting and linter errors
VegetarianOrc Nov 11, 2025
835d43d
Remove return repetitive types in OperationHandler.start. Make Operat…
VegetarianOrc Nov 14, 2025
4205c93
Move deploy-docs to it's own workflow that runs on push to main
VegetarianOrc Nov 14, 2025
d933f83
Fix workflow name in deploy-docs
VegetarianOrc Nov 14, 2025
96ddecf
export LazyValueT and Serializer from _serializer.py
VegetarianOrc Nov 14, 2025
59c96d8
remove the work 'docs' from the 'lint-test' job
VegetarianOrc Nov 14, 2025
5f2a399
Rename AwaitableOperationHandler to MiddlewareSafeOperationHandler
VegetarianOrc Nov 15, 2025
ae3b58f
Run formatter
VegetarianOrc Nov 15, 2025
e880657
Merge branch 'main' into interceptors
VegetarianOrc Nov 17, 2025
b42c0e9
remove generic args in MiddlewareSafeOperationHandler since it by def…
VegetarianOrc Nov 20, 2025
0f37ce8
Finish removing generic args from MiddlewareSafeOperationHandler
VegetarianOrc Nov 20, 2025
bde8337
Update old reference from 'interceptors' -> 'middleware'
VegetarianOrc Nov 24, 2025
875e2ea
Remove _all_ reference to interceptors
VegetarianOrc Nov 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 4 additions & 35 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ name: CI
on:
pull_request:
push:
branches: [ main ]
branches:
- main

jobs:
lint-test-docs:
lint-test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.10', '3.13', '3.14']
python-version: ['3.10', '3.14']
os: [ubuntu-latest, macos-latest, windows-latest]

steps:
Expand Down Expand Up @@ -38,35 +39,3 @@ jobs:
with:
name: coverage-html-report-${{ matrix.os }}-${{ matrix.python-version }}
path: coverage_html_report/

deploy-docs:
runs-on: ubuntu-latest
needs: lint-test-docs
# TODO(preview): deploy on releases only
permissions:
contents: read
pages: write
id-token: write

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
python-version: '3.10'

- name: Install dependencies
run: uv sync

- name: Build API docs
run: uv run poe docs

- name: Upload docs to GitHub Pages
uses: actions/upload-pages-artifact@v3
with:
path: apidocs

- name: Deploy to GitHub Pages
uses: actions/deploy-pages@v4
37 changes: 37 additions & 0 deletions .github/workflows/deploy-docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Deploy Docs

on:
push:
branches:
- main

jobs:
deploy-docs:
runs-on: ubuntu-latest
permissions:
contents: read
pages: write
id-token: write

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
python-version: '3.10'

- name: Install dependencies
run: uv sync

- name: Build API docs
run: uv run poe docs

- name: Upload docs to GitHub Pages
uses: actions/upload-pages-artifact@v3
with:
path: apidocs

- name: Deploy to GitHub Pages
uses: actions/deploy-pages@v4
4 changes: 3 additions & 1 deletion src/nexusrpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
OperationErrorState,
OutputT,
)
from ._serializer import Content, LazyValue
from ._serializer import Content, LazyValue, LazyValueT, Serializer
from ._service import Operation, OperationDefinition, ServiceDefinition, service
from ._util import (
get_operation,
Expand All @@ -42,12 +42,14 @@
"HandlerErrorType",
"InputT",
"LazyValue",
"LazyValueT",
"Link",
"Operation",
"OperationDefinition",
"OperationError",
"OperationErrorState",
"OutputT",
"Serializer",
"service",
"ServiceDefinition",
"set_operation",
Expand Down
7 changes: 5 additions & 2 deletions src/nexusrpc/handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
StartOperationResultAsync,
StartOperationResultSync,
)
from ._core import Handler as Handler
from ._core import Handler, OperationHandlerMiddleware
from ._decorators import operation_handler, service_handler, sync_operation
from ._operation_handler import OperationHandler as OperationHandler
from ._operation_handler import MiddlewareSafeOperationHandler, OperationHandler

__all__ = [
"MiddlewareSafeOperationHandler",
"CancelOperationContext",
"Handler",
"OperationContext",
"OperationHandler",
"OperationTaskCancellation",
"OperationHandlerMiddleware",
"operation_handler",
"service_handler",
"StartOperationContext",
"StartOperationResultAsync",
Expand Down
110 changes: 93 additions & 17 deletions src/nexusrpc/handler/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast

from typing_extensions import Self, TypeGuard

Expand All @@ -113,11 +113,13 @@

from ._common import (
CancelOperationContext,
OperationContext,
StartOperationContext,
StartOperationResultAsync,
StartOperationResultSync,
)
from ._operation_handler import (
MiddlewareSafeOperationHandler,
OperationHandler,
collect_operation_handler_factories_by_method_name,
)
Expand Down Expand Up @@ -248,7 +250,9 @@ def __init__(
self,
user_service_handlers: Sequence[Any],
executor: Optional[concurrent.futures.Executor] = None,
middleware: Sequence[OperationHandlerMiddleware] | None = None,
):
self._middleware = cast(Sequence[OperationHandlerMiddleware], middleware or [])
super().__init__(user_service_handlers, executor=executor)
if not self.executor:
self._validate_all_operation_handlers_are_async()
Expand All @@ -268,17 +272,11 @@ async def start_operation(
input: The input to the operation, as a LazyValue.
"""
service_handler = self._get_service_handler(ctx.service)
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation)

op_defn = service_handler.service.operation_definitions[ctx.operation]
deserialized_input = await input.consume(as_type=op_defn.input_type)
# TODO(preview): apply middleware stack
if is_async_callable(op_handler.start):
return await op_handler.start(ctx, deserialized_input)
else:
assert self.executor
return await self.executor.submit_to_event_loop(
op_handler.start, ctx, deserialized_input
)
return await op_handler.start(ctx, deserialized_input)

async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> None:
"""Handle a Cancel Operation request.
Expand All @@ -288,12 +286,23 @@ async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> Non
token: The operation token.
"""
service_handler = self._get_service_handler(ctx.service)
op_handler = service_handler._get_operation_handler(ctx.operation) # pyright: ignore[reportPrivateUsage]
if is_async_callable(op_handler.cancel):
return await op_handler.cancel(ctx, token)
else:
assert self.executor
return self.executor.submit(op_handler.cancel, ctx, token).result()
op_handler = self._get_operation_handler(ctx, service_handler, ctx.operation)
return await op_handler.cancel(ctx, token)

def _get_operation_handler(
self, ctx: OperationContext, service_handler: ServiceHandler, operation: str
) -> MiddlewareSafeOperationHandler:
"""
Get the specified handler for the specified operation from the given service_handler and apply all middleware.
"""
op_handler: MiddlewareSafeOperationHandler = _EnsuredAwaitableOperationHandler(
self.executor, service_handler.get_operation_handler(operation)
)

for middleware in reversed(self._middleware):
op_handler = middleware.intercept(ctx, op_handler)

return op_handler

def _validate_all_operation_handlers_are_async(self) -> None:
for service_handler in self.service_handlers.values():
Expand Down Expand Up @@ -360,7 +369,7 @@ def from_user_instance(cls, user_instance: Any) -> Self:
operation_handlers=op_handlers,
)

def _get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]:
def get_operation_handler(self, operation_name: str) -> OperationHandler[Any, Any]:
"""Return an operation handler, given the operation name."""
if operation_name not in self.service.operation_definitions:
raise HandlerError(
Expand Down Expand Up @@ -401,3 +410,70 @@ def submit(
self, fn: Callable[..., Any], *args: Any
) -> concurrent.futures.Future[Any]:
return self._executor.submit(fn, *args)


class OperationHandlerMiddleware(ABC):
"""
Middleware for operation handlers.

This should be extended by any operation handler middelware.
"""

@abstractmethod
def intercept(
self,
ctx: OperationContext, # type: ignore[reportUnusedParameter]
next: MiddlewareSafeOperationHandler,
) -> MiddlewareSafeOperationHandler:
"""
Method called for intercepting operation handlers.

Args:
ctx: The :py:class:`OperationContext` that will be passed to the operation handler.
next: The underlying operation handler that this middleware
should delegate to.

Returns:
The new middleware that will be used to invoke
:py:attr:`OperationHandler.start` or :py:attr:`OperationHandler.cancel`.
"""
...


class _EnsuredAwaitableOperationHandler(MiddlewareSafeOperationHandler):
"""
An :py:class:`AwaitableOperationHandler` that wraps an :py:class:`OperationHandler` and uses an :py:class:`_Executor` to ensure
that the :py:attr:`start` and :py:attr:`cancel` methods are awaitable.
"""

def __init__(
self,
executor: _Executor | None,
op_handler: OperationHandler[Any, Any],
):
self._executor = executor
self._op_handler = op_handler

async def start(
self, ctx: StartOperationContext, input: Any
) -> StartOperationResultSync[Any] | StartOperationResultAsync:
"""
Start the operation using the wrapped :py:class:`OperationHandler`.
"""
if is_async_callable(self._op_handler.start):
return await self._op_handler.start(ctx, input)
else:
assert self._executor
return await self._executor.submit_to_event_loop(
self._op_handler.start, ctx, input
)

async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
"""
Cancel an operation using the wrapped :py:class:`OperationHandler`.
"""
if is_async_callable(self._op_handler.cancel):
return await self._op_handler.cancel(ctx, token)
else:
assert self._executor
return self._executor.submit(self._op_handler.cancel, ctx, token).result()
42 changes: 32 additions & 10 deletions src/nexusrpc/handler/_operation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from typing import Any, Callable, Generic, Optional, Union
from typing import Any, Callable, Generic, Optional

from nexusrpc._common import InputT, OutputT, ServiceHandlerT
from nexusrpc._service import Operation, OperationDefinition, ServiceDefinition
Expand Down Expand Up @@ -39,12 +39,11 @@ class OperationHandler(ABC, Generic[InputT, OutputT]):
@abstractmethod
def start(
self, ctx: StartOperationContext, input: InputT
) -> Union[
StartOperationResultSync[OutputT],
Awaitable[StartOperationResultSync[OutputT]],
StartOperationResultAsync,
Awaitable[StartOperationResultAsync],
]:
) -> (
StartOperationResultSync[OutputT]
| StartOperationResultAsync
| Awaitable[StartOperationResultSync[OutputT] | StartOperationResultAsync]
):
"""
Start the operation, completing either synchronously or asynchronously.

Expand All @@ -54,9 +53,7 @@ def start(
...

@abstractmethod
def cancel(
self, ctx: CancelOperationContext, token: str
) -> Union[None, Awaitable[None]]:
def cancel(self, ctx: CancelOperationContext, token: str) -> None | Awaitable[None]:
"""
Cancel the operation.
"""
Expand Down Expand Up @@ -104,6 +101,31 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
)


class MiddlewareSafeOperationHandler(OperationHandler[Any, Any], ABC):
"""
An :py:class:`OperationHandler` where :py:attr:`start` and :py:attr:`cancel`
can be awaited by an async runtime. It can produce a result synchronously by returning
:py:class:`StartOperationResultSync` or asynchronously by returning :py:class:`StartOperationResultAsync`
in the same fashion that :py:class:`OperationHandler` does.
"""

@abstractmethod
async def start(
self, ctx: StartOperationContext, input: Any
) -> StartOperationResultSync[Any] | StartOperationResultAsync:
"""
Start the operation and return it's result or an async token.
"""
...

@abstractmethod
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
"""
Cancel an in progress operation identified by the given token.
"""
...


def collect_operation_handler_factories_by_method_name(
user_service_cls: type[ServiceHandlerT],
service: Optional[ServiceDefinition],
Expand Down
2 changes: 1 addition & 1 deletion tests/handler/test_async_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
OperationHandler,
StartOperationContext,
StartOperationResultAsync,
operation_handler,
service_handler,
)
from nexusrpc.handler._decorators import operation_handler
from tests.helpers import DummySerializer, TestOperationTaskCancellation

_operation_results: dict[str, int] = {}
Expand Down
Loading