Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dispatch/auth/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND

from dispatch.enums import UserRoles, Visibility
from dispatch.auth.service import get_current_user
from dispatch.enums import UserRoles, Visibility
from dispatch.case import service as case_service
from dispatch.case.models import Case
from dispatch.incident import service as incident_service
Expand Down
5 changes: 4 additions & 1 deletion src/dispatch/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
:license: Apache, see LICENSE for more details.
"""
import logging
from typing import Optional
from typing import Annotated, Optional

from fastapi import HTTPException, Depends
from starlette.requests import Request
Expand Down Expand Up @@ -248,6 +248,9 @@ def get_current_user(request: Request) -> DispatchUser:
)


CurrentUser = Annotated[DispatchUser, Depends(get_current_user)]


def get_current_role(
request: Request, current_user: DispatchUser = Depends(get_current_user)
) -> UserRoles:
Expand Down
21 changes: 10 additions & 11 deletions src/dispatch/auth/views.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from sqlalchemy.orm import Session

from dispatch.config import DISPATCH_AUTH_REGISTRATION_ENABLED

from dispatch.auth.permissions import (
OrganizationMemberPermission,
PermissionsDependency,
)
from dispatch.auth.service import CurrentUser
from dispatch.exceptions import (
InvalidConfigurationError,
InvalidPasswordError,
InvalidUsernameError,
)
from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.enums import UserRoles
from dispatch.models import OrganizationSlug, PrimaryKey
from dispatch.organization.models import OrganizationRead

from .models import (
DispatchUser,
UserLogin,
UserLoginResponse,
UserOrganization,
Expand All @@ -30,7 +29,7 @@
UserRegisterResponse,
UserUpdate,
)
from .service import get, get_by_email, update, create, get_current_user
from .service import get, get_by_email, update, create


auth_router = APIRouter()
Expand Down Expand Up @@ -74,7 +73,7 @@ def get_users(*, organization: OrganizationSlug, common: dict = Depends(common_p


@user_router.get("/{user_id}", response_model=UserRead)
def get_user(*, db_session: Session = Depends(get_db), user_id: PrimaryKey):
def get_user(*, db_session: DbSession, user_id: PrimaryKey):
"""Get a user."""
user = get(db_session=db_session, user_id=user_id)
if not user:
Expand All @@ -92,11 +91,11 @@ def get_user(*, db_session: Session = Depends(get_db), user_id: PrimaryKey):
)
def update_user(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
user_id: PrimaryKey,
organization: OrganizationSlug,
user_in: UserUpdate,
current_user: DispatchUser = Depends(get_current_user),
current_user: CurrentUser,
):
"""Update a user."""
user = get(db_session=db_session, user_id=user_id)
Expand Down Expand Up @@ -131,8 +130,8 @@ def update_user(
@auth_router.get("/me", response_model=UserRead)
def get_me(
*,
db_session: Session = Depends(get_db),
current_user: DispatchUser = Depends(get_current_user),
db_session: DbSession,
current_user: CurrentUser,
):
return current_user

Expand All @@ -141,7 +140,7 @@ def get_me(
def login_user(
user_in: UserLogin,
organization: OrganizationSlug,
db_session: Session = Depends(get_db),
db_session: DbSession,
):
user = get_by_email(db_session=db_session, email=user_in.email)
if user and user.check_password(user_in.password):
Expand Down Expand Up @@ -174,7 +173,7 @@ def login_user(
def register_user(
user_in: UserRegister,
organization: OrganizationSlug,
db_session: Session = Depends(get_db),
db_session: DbSession,
):
user = get_by_email(db_session=db_session, email=user_in.email)
if user:
Expand Down
9 changes: 4 additions & 5 deletions src/dispatch/case/priority/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.models import PrimaryKey
Expand Down Expand Up @@ -31,7 +30,7 @@ def get_case_priorities(*, common: dict = Depends(common_parameters)):
)
def create_case_priority(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_priority_in: CasePriorityCreate,
):
"""Creates a new case priority."""
Expand All @@ -46,7 +45,7 @@ def create_case_priority(
)
def update_case_priority(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_priority_id: PrimaryKey,
case_priority_in: CasePriorityUpdate,
):
Expand All @@ -67,7 +66,7 @@ def update_case_priority(


@router.get("/{case_priority_id}", response_model=CasePriorityRead)
def get_case_priority(*, db_session: Session = Depends(get_db), case_priority_id: PrimaryKey):
def get_case_priority(*, db_session: DbSession, case_priority_id: PrimaryKey):
"""Gets a case priority."""
case_priority = get(db_session=db_session, case_priority_id=case_priority_id)
if not case_priority:
Expand Down
9 changes: 4 additions & 5 deletions src/dispatch/case/severity/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.models import PrimaryKey
Expand Down Expand Up @@ -31,7 +30,7 @@ def get_case_severities(*, common: dict = Depends(common_parameters)):
)
def create_case_severity(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_severity_in: CaseSeverityCreate,
):
"""Creates a new case severity."""
Expand All @@ -46,7 +45,7 @@ def create_case_severity(
)
def update_case_severity(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_severity_id: PrimaryKey,
case_severity_in: CaseSeverityUpdate,
):
Expand All @@ -67,7 +66,7 @@ def update_case_severity(


@router.get("/{case_severity_id}", response_model=CaseSeverityRead)
def get_case_severity(*, db_session: Session = Depends(get_db), case_severity_id: PrimaryKey):
def get_case_severity(*, db_session: DbSession, case_severity_id: PrimaryKey):
"""Gets a case severity."""
case_severity = get(db_session=db_session, case_severity_id=case_severity_id)
if not case_severity:
Expand Down
9 changes: 4 additions & 5 deletions src/dispatch/case/type/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from dispatch.auth.permissions import SensitiveProjectActionPermission, PermissionsDependency
from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.models import PrimaryKey

Expand All @@ -26,7 +25,7 @@ def get_case_types(*, common: dict = Depends(common_parameters)):
)
def create_case_type(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_type_in: CaseTypeCreate,
):
"""Creates a new case type."""
Expand All @@ -40,7 +39,7 @@ def create_case_type(
)
def update_case_type(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
case_type_id: PrimaryKey,
case_type_in: CaseTypeUpdate,
):
Expand All @@ -55,7 +54,7 @@ def update_case_type(


@router.get("/{case_type_id}", response_model=CaseTypeRead)
def get_case_type(*, db_session: Session = Depends(get_db), case_type_id: PrimaryKey):
def get_case_type(*, db_session: DbSession, case_type_id: PrimaryKey):
"""Gets a case type."""
case_type = get(db_session=db_session, case_type_id=case_type_id)
if not case_type:
Expand Down
42 changes: 19 additions & 23 deletions src/dispatch/case/views.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import logging
from typing import List
from typing import Annotated, List

import json

from starlette.requests import Request
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

# NOTE: define permissions before enabling the code block below
from dispatch.auth.permissions import (
Expand All @@ -16,11 +15,10 @@
PermissionsDependency,
CaseViewPermission,
)
from dispatch.auth import service as auth_service
from dispatch.auth.models import DispatchUser
from dispatch.auth.service import CurrentUser
from dispatch.case.enums import CaseStatus
from dispatch.common.utils.views import create_pydantic_include
from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.database.service import common_parameters, search_filter_sort_paginate
from dispatch.models import OrganizationSlug, PrimaryKey
from dispatch.incident.models import IncidentCreate, IncidentRead
Expand All @@ -46,7 +44,7 @@
router = APIRouter()


def get_current_case(*, db_session: Session = Depends(get_db), request: Request) -> Case:
def get_current_case(db_session: DbSession, request: Request) -> Case:
"""Fetches a case or returns an HTTP 404."""
case = get(db_session=db_session, case_id=request.path_params["case_id"])
if not case:
Expand All @@ -57,17 +55,19 @@ def get_current_case(*, db_session: Session = Depends(get_db), request: Request)
return case


CurrentCase = Annotated[Case, Depends(get_current_case)]


@router.get(
"/{case_id}",
response_model=CaseRead,
summary="Retrieves a single case.",
dependencies=[Depends(PermissionsDependency([CaseViewPermission]))],
)
def get_case(
*,
case_id: PrimaryKey,
db_session: Session = Depends(get_db),
current_case: Case = Depends(get_current_case),
db_session: DbSession,
current_case: CurrentCase,
):
"""Retrieves the details of a single case."""
return current_case
Expand Down Expand Up @@ -98,11 +98,10 @@ def get_cases(

@router.post("", response_model=CaseRead, summary="Creates a new case.")
def create_case(
*,
db_session: Session = Depends(get_db),
db_session: DbSession,
organization: OrganizationSlug,
case_in: CaseCreate,
current_user: DispatchUser = Depends(auth_service.get_current_user),
current_user: CurrentUser,
background_tasks: BackgroundTasks,
):
"""Creates a new case."""
Expand Down Expand Up @@ -149,13 +148,12 @@ def create_case(
dependencies=[Depends(PermissionsDependency([CaseEditPermission]))],
)
def update_case(
*,
db_session: Session = Depends(get_db),
current_case: Case = Depends(get_current_case),
db_session: DbSession,
current_case: CurrentCase,
organization: OrganizationSlug,
case_id: PrimaryKey,
case_in: CaseUpdate,
current_user: DispatchUser = Depends(auth_service.get_current_user),
current_user: CurrentUser,
background_tasks: BackgroundTasks,
):
"""Updates an existing case."""
Expand Down Expand Up @@ -187,12 +185,11 @@ def update_case(
dependencies=[Depends(PermissionsDependency([CaseEditPermission]))],
)
def escalate_case(
*,
db_session: Session = Depends(get_db),
current_case: Case = Depends(get_current_case),
db_session: DbSession,
current_case: CurrentCase,
organization: OrganizationSlug,
incident_in: IncidentCreate,
current_user: DispatchUser = Depends(auth_service.get_current_user),
current_user: CurrentUser,
background_tasks: BackgroundTasks,
):
"""Escalates an existing case."""
Expand Down Expand Up @@ -229,10 +226,9 @@ def escalate_case(
dependencies=[Depends(PermissionsDependency([CaseEditPermission]))],
)
def delete_case(
*,
case_id: PrimaryKey,
db_session: Session = Depends(get_db),
current_case: Case = Depends(get_current_case),
db_session: DbSession,
current_case: CurrentCase,
):
"""Deletes an existing case and its external resources."""
# we run the case delete flow
Expand Down
15 changes: 6 additions & 9 deletions src/dispatch/data/alert/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from fastapi import APIRouter, HTTPException, status

from dispatch.database.core import get_db
from dispatch.database.core import DbSession
from dispatch.models import PrimaryKey

from .models import (
Expand All @@ -15,7 +14,7 @@


@router.get("/{alert_id}", response_model=AlertRead)
def get_alert(*, db_session: Session = Depends(get_db), alert_id: PrimaryKey):
def get_alert(*, db_session: DbSession, alert_id: PrimaryKey):
"""Given its unique id, retrieve details about a single alert."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand All @@ -27,15 +26,13 @@ def get_alert(*, db_session: Session = Depends(get_db), alert_id: PrimaryKey):


@router.post("", response_model=AlertRead)
def create_alert(*, db_session: Session = Depends(get_db), alert_in: AlertCreate):
def create_alert(*, db_session: DbSession, alert_in: AlertCreate):
"""Creates a new alert."""
return create(db_session=db_session, alert_in=alert_in)


@router.put("/{alert_id}", response_model=AlertRead)
def update_alert(
*, db_session: Session = Depends(get_db), alert_id: PrimaryKey, alert_in: AlertUpdate
):
def update_alert(*, db_session: DbSession, alert_id: PrimaryKey, alert_in: AlertUpdate):
"""Updates an alert."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand All @@ -47,7 +44,7 @@ def update_alert(


@router.delete("/{alert_id}", response_model=None)
def delete_alert(*, db_session: Session = Depends(get_db), alert_id: PrimaryKey):
def delete_alert(*, db_session: DbSession, alert_id: PrimaryKey):
"""Deletes an alert, returning only an HTTP 200 OK if successful."""
alert = get(db_session=db_session, alert_id=alert_id)
if not alert:
Expand Down
Loading