Skip to content

Commit d7f9686

Browse files
authored
fix: apply api middleware correctly (#129)
1 parent 46842c3 commit d7f9686

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

nitric/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,9 @@ async def chained_middleware(ctx: C, nxt: Optional[Middleware[C]] = None) -> C:
439439

440440
return chained_middleware
441441

442-
middleware_chain = functools.reduce(reduce_chain, reversed(middlewares)) # type: ignore
442+
middleware_chain = functools.reduce(reduce_chain, reversed(middlewares), last_middleware) # type: ignore
443443
# type ignored because mypy appears to misidentify the correct return type
444-
return await middleware_chain(ctx, last_middleware) # type: ignore
444+
return await middleware_chain(ctx) # type: ignore
445445

446446
return composed
447447

nitric/resources/apis.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ def _route(self, match: str, opts: Optional[RouteOptions] = None) -> Route:
201201
if opts is None:
202202
opts = RouteOptions()
203203

204+
if self.middleware is not None:
205+
opts.middleware = self.middleware + opts.middleware
206+
204207
r = Route(self, match, opts)
205208
self.routes.append(r)
206209
return r
@@ -339,6 +342,13 @@ def method(
339342
self, methods: List[HttpMethod], *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None
340343
) -> None:
341344
"""Register middleware for multiple HTTP Methods."""
345+
346+
# ensure route/api middlewares are added
347+
middleware = (
348+
*self.middleware,
349+
*middleware
350+
)
351+
342352
Method(self, methods, *middleware, opts=opts if opts else MethodOptions())
343353

344354
def get(self, *middleware: HttpMiddleware | HttpHandler, opts: Optional[MethodOptions] = None) -> None:

tests/resources/test_apis.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
# from nitric.faas import HttpMethod, MethodOptions, ApiWorkerOptions
2828
from nitric.resources import api, ApiOptions, JwtSecurityDefinition
29-
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule
29+
from nitric.resources.apis import MethodOptions, ScopedOidcOptions, oidc_rule, HttpMiddleware
3030
from nitric.proto.resources.v1 import (
3131
ApiOpenIdConnectionDefinition,
3232
ApiSecurityDefinitionResource,
@@ -40,6 +40,7 @@
4040
from nitric.proto.apis.v1 import ApiDetailsResponse, ApiDetailsRequest, ApiWorkerScopes
4141

4242
from nitric.context import (
43+
HttpContext,
4344
HttpMethod,
4445
)
4546

@@ -221,6 +222,23 @@ def test_api_route(self):
221222
assert test_route.middleware == []
222223
assert test_route.api.name == test_api.name
223224

225+
def test_api_route_middleware(self):
226+
mock_declare = AsyncMock()
227+
mock_response = Object()
228+
mock_declare.return_value = mock_response
229+
230+
async def middleware_test(ctx: HttpContext, nxt: HttpMiddleware):
231+
return nxt(ctx)
232+
233+
with patch("nitric.proto.resources.v1.ResourcesStub.declare", mock_declare):
234+
test_api = api("test-api-route-middleware", ApiOptions(path="/api/v2/", middleware=[middleware_test]))
235+
236+
test_route = test_api._route("/test")
237+
238+
assert len(test_api.middleware) == 1
239+
assert len(test_route.middleware) == 1
240+
241+
224242
def test_define_route(self):
225243
mock_declare = AsyncMock()
226244
mock_response = Object()

0 commit comments

Comments
 (0)