Skip to content

Separating route securities from other dependencies. #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Fastapi app creation."""


from typing import Any, Dict, List, Optional, Tuple, Type, Union

import attr
Expand All @@ -26,7 +25,12 @@
ItemUri,
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint
from stac_fastapi.api.routes import (
Scope,
add_route_dependencies,
add_route_securities,
create_async_endpoint,
)
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
Expand Down Expand Up @@ -120,6 +124,7 @@ class StacApi:
)
)
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])
route_securities: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])

def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]:
"""Get an extension.
Expand Down Expand Up @@ -225,9 +230,9 @@ def register_post_search(self):
self.router.add_api_route(
name="Search",
path="/search",
response_model=api.ItemCollection
if self.settings.enable_response_models
else None,
response_model=(
api.ItemCollection if self.settings.enable_response_models else None
),
responses={
200: {
"content": {
Expand All @@ -254,9 +259,9 @@ def register_get_search(self):
self.router.add_api_route(
name="Search",
path="/search",
response_model=api.ItemCollection
if self.settings.enable_response_models
else None,
response_model=(
api.ItemCollection if self.settings.enable_response_models else None
),
responses={
200: {
"content": {
Expand Down Expand Up @@ -312,9 +317,9 @@ def register_get_collection(self):
self.router.add_api_route(
name="Get Collection",
path="/collections/{collection_id}",
response_model=api.Collection
if self.settings.enable_response_models
else None,
response_model=(
api.Collection if self.settings.enable_response_models else None
),
responses={
200: {
"content": {
Expand Down Expand Up @@ -431,6 +436,17 @@ def add_route_dependencies(
"""
return add_route_dependencies(self.app.router.routes, scopes, dependencies)

def add_route_securities(self) -> None:
"""Add custom securities to routes.

Returns:
None
"""
return add_route_securities(
self.app.router.routes,
self.route_securities,
)

def __attrs_post_init__(self):
"""Post-init hook.

Expand Down Expand Up @@ -479,3 +495,5 @@ def __attrs_post_init__(self):
# customize route dependencies
for scopes, dependencies in self.route_dependencies:
self.add_route_dependencies(scopes=scopes, dependencies=dependencies)

self.add_route_securities()
174 changes: 172 additions & 2 deletions stac_fastapi/api/stac_fastapi/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import copy
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict, Union

from fastapi import Depends, params
from fastapi import Depends, HTTPException, params, status
from fastapi.dependencies.utils import get_parameterless_sub_dependant
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
Expand Down Expand Up @@ -85,6 +86,175 @@ class Scope(TypedDict, total=False):
type: Optional[str]


def merge_dependencies1(*dependencies: Callable) -> Callable:
"""
This function wraps the given callables (dependencies) and
wraps them in FastAPIs Depends. It returns a function
containing these dependencies in its signature.

:param dependencies: The dependencies to wrap
:return: A callable which returns a list of the results of
the dependencies
"""

def merged_dependencies(**kwargs):
result = next((item for item in kwargs.values() if item is not None), None)
if not result:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)
return result

merged_dependencies.__signature__ = inspect.Signature( # type: ignore
parameters=[
inspect.Parameter(f"dep{key}", inspect.Parameter.KEYWORD_ONLY, default=dep)
for key, dep in enumerate(dependencies)
]
)
return merged_dependencies


def merge_dependencies2(*dependencies: Callable) -> Callable:
"""
This function wraps the given callables (dependencies) and
wraps them in FastAPIs Depends. It returns a function
containing these dependencies in its signature.

:param dependencies: The dependencies to wrap
:return: A callable which returns the first non none result of
the dependencies
"""

async def merged_dependencies(**kwargs):
for dep_key, dep in dependencies_key.items():
dep_kwargs = {
kwarg_key.removeprefix(dep_key): kwarg_value
for kwarg_key, kwarg_value in kwargs.items()
if kwarg_key.startswith(dep_key)
}

try:
result = await dep.dependency(**dep_kwargs)

if result:
return result

except HTTPException as e:
if e.status_code != 401:
raise e

continue

raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
)

dependencies_key = {}
sub_dependencies = []
for key, dep in enumerate(dependencies):
dependencies_key[f"dep{key}_"] = dep

if isinstance(dep, params.Depends):
for sub_key, parameter in inspect.signature(
dep.dependency
).parameters.items():
dummy_parameter = parameter.replace(name=f"dep{key}_{sub_key}")
sub_dependencies.append(dummy_parameter)

merged_dependencies.__signature__ = inspect.Signature( # type: ignore
parameters=sub_dependencies
)
return merged_dependencies


def map_route_securities(
route_securites: List[Tuple[List[Scope], List[Depends]]],
) -> Dict:
"""Map route securities.

Allows a developer to add dependencies to a route after the route has been
defined.

"*" can be used for path or method to match all allowed routes.

Returns:
None
"""
mapped_scopes = defaultdict(list)

for scopes, securities in route_securites:
for scope in scopes:
for method in scope["method"]:
mapped_scopes[scope["path"]][method].append(securities)

return mapped_scopes


def add_route_securities(
routes: List[BaseRoute], route_securites: List[Tuple[List[Scope], List[Depends]]]
) -> None:
"""Add securities to routes.

Allows a developer to add securities to a route after the route has been
defined.

"*" can be used for path or method to match all allowed routes or methods.

Returns:
None
"""
mapped_scopes = map_route_securities(route_securites=route_securites)

default_scope = mapped_scopes.pop("*", {})

for route in routes:
if not hasattr(route, "dependant"):
continue

scope = mapped_scopes.get(route.path, defaultdict(list))
route_securities = []

for default_method, default_security in default_scope.items():
scope[default_method].append(default_security)

for method, security in scope.items():
method = scope["method"]
if method == "*":
method = list(route.methods)[0]

match, _ = route.matches(
{"type": "http", **{"routes": [route.path], "method": [method]}}
)
if match:
route_securities.append(security)

route_security = (
Depends(merge_dependencies2(*route_securities))
if len(route_securities) > 1
else route_securities[0]
)

# route_security = Depends(merge_dependencies1(*route_securities))

# Mimicking how APIRoute handles dependencies:
# https://github.com/tiangolo/fastapi/blob/1760da0efa55585c19835d81afa8ca386036c325/fastapi/routing.py#L408-L412
route.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(
depends=route_security, path=route.path_format
),
)

# Register dependencies directly on route so that they aren't ignored if
# the routes are later associated with an app (e.g.
# app.include_router(router))
# https://github.com/tiangolo/fastapi/blob/58ab733f19846b4875c5b79bfb1f4d1cb7f4823f/fastapi/applications.py#L337-L360
# https://github.com/tiangolo/fastapi/blob/58ab733f19846b4875c5b79bfb1f4d1cb7f4823f/fastapi/routing.py#L677-L678
route.dependencies.extend(route_securities)


def add_route_dependencies(
routes: List[BaseRoute], scopes: List[Scope], dependencies=List[params.Depends]
) -> None:
Expand Down
Loading