Skip to content

Commit

Permalink
style(exc): merge exc to one file
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jan 28, 2024
1 parent 96c29ef commit 97b16bc
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 127 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fixable = ["ALL"]
[tool.ruff.extend-per-file-ignores]
"env.py" = ["INP001", "I001", "ERA001"]
"tests/*.py" = ["S101"]
"exception_handlers.py" = ["ARG001"]
"exceptions.py" = ["ARG001"]
"models.py" = ["RUF012"]
"api.py" = ["A002", "B008"]
"deps.py" = ["B008"]
Expand Down
3 changes: 1 addition & 2 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

from src.config import settings
from src.enums import Env
from src.exception_handlers import default_exception_handler, exception_handlers
from src.exceptions import sentry_ignore_errors
from src.exceptions import default_exception_handler, exception_handlers, sentry_ignore_errors
from src.middlewares import RequestMiddleware
from src.openapi import openapi_description
from src.routers import router
Expand Down
32 changes: 19 additions & 13 deletions src/db/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import datetime
from typing import TYPE_CHECKING

from fastapi.encoders import jsonable_encoder
Expand All @@ -14,12 +13,15 @@
from src.db.base import Base

if TYPE_CHECKING:
from datetime import datetime

from src.auth.models import User
from src.db.dtobase import ModelT


class AuditTimeMixin:
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=func.now(), index=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), index=True)
updated_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), onupdate=func.now())


def get_object_change(obj: Mapper) -> dict:
Expand All @@ -29,6 +31,8 @@ def get_object_change(obj: Mapper) -> dict:
"diff": {},
}
for attr in class_mapper(obj.__class__).column_attrs:
before = None
after = None
if getattr(insp.attrs, attr.key).hisotry.has_changes():
if get_history(obj, attr.key)[2]:
before = get_history(obj, attr.key)[2].pop()
Expand All @@ -38,16 +42,15 @@ def get_object_change(obj: Mapper) -> dict:
after = getattr(obj, attr.key)
if before != after:
changes["diff"][attr.key] = {"before": before, "after": after}
return jsonable_encoder(changes)
return jsonable_encoder(changes, exclude={"children", "parent"})


class AuditLog:
id: Mapped[int_pk]
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=func.now())
created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now())
request_id: Mapped[str]
action: Mapped[str] = mapped_column(String, nullable=False)
diff: Mapped[dict | None] = mapped_column(JSON)
post_change: Mapped[dict | None] = mapped_column(JSON)

@declared_attr
@classmethod
Expand All @@ -68,7 +71,7 @@ def user(cls) -> Mapped["User"]:
class AuditLogMixin:
@declared_attr
@classmethod
def audit_log(cls) -> Mapped[list["AuditLog"]]:
def audit_log(cls: type["ModelT"]) -> Mapped[list["AuditLog"]]:
cls.AuditLog = type(
f"{cls.__name__}AuditLog",
(AuditLog, Base),
Expand All @@ -85,20 +88,20 @@ def audit_log(cls) -> Mapped[list["AuditLog"]]:
return relationship(cls.AuditLog)

@classmethod
def log_create(cls, mapper: Mapper, connection: Connection, target: Mapper) -> None: # noqa: ARG003
def log_create(cls, mapper: Mapper, connection: Connection, target: "ModelT") -> None: # noqa: ARG003
connection.execute(
insert(cls.AuditLog),
{
"request_id": request_id_ctx.get(),
"action": "create",
"post_change": target.dict(exclude_relationship=True),
"diff": target.dict(native_dict=True),
"parent_id": target.id,
"user_id": user_ctx.get(),
},
)

@classmethod
def log_update(cls, mapper: Mapper, connection: Connection, target: Mapper) -> None: # noqa: ARG003
def log_update(cls, mapper: Mapper, connection: Connection, target: "ModelT") -> None: # noqa: ARG003
changes = get_object_change(target)
if changes is not None:
orm_diff_ctx.set(changes)
Expand All @@ -114,13 +117,13 @@ def log_update(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N
)

@classmethod
def log_delete(cls, mapper: Mapper, connection: Connection, target: Mapper) -> None: # noqa: ARG003
def log_delete(cls, mapper: Mapper, connection: Connection, target: "ModelT") -> None: # noqa: ARG003
connection.execute(
insert(cls.AuditLog),
{
"request_id": request_id_ctx.get(),
"action": "delete",
"diff": target.dict(exclude_relationship=True),
"diff": target.dict(native_dict=True),
"parent_id": target.id,
"user_id": user_ctx.get(),
},
Expand All @@ -134,10 +137,13 @@ def __declare_last__(cls) -> None:


class AuditUserMixin:
created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), index=True)
updated_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), onupdate=func.now())

@declared_attr
@classmethod
def created_by_fk(cls) -> Mapped[int | None]:
return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True)
return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get)

@declared_attr
@classmethod
Expand Down
108 changes: 0 additions & 108 deletions src/exception_handlers.py

This file was deleted.

107 changes: 104 additions & 3 deletions src/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import logging
import sys
import traceback
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any
from typing import Any, NewType
from uuid import UUID

from fastapi import status
from fastapi import Request, status
from fastapi.responses import JSONResponse

from src.context import locale_ctx
from src import errors
from src.context import locale_ctx, request_id_ctx
from src.errors import ErrorCode
from src.i18n import _

_E = NewType("_E", Exception)
logger = logging.getLogger(__name__)


def error_message_value_handler(value: Any) -> Any:
Expand Down Expand Up @@ -69,6 +78,98 @@ def __repr__(self) -> str:
return f"Gener Error Occurred: ErrCode: {self.error.error}, Message: {self.error.message}"


def log_exception(exc: type[BaseException] | Exception, logger_trace_info: bool) -> None:
"""
Logs an exception.
Args:
exc (Type[BaseException] | Exception): The exception to be logged.
logger_trace_info (bool): Indicates whether to include detailed trace information in the log.
Returns:
None
Raises:
N/A
"""
logger = logging.getLogger(__name__)
ex_type, _, ex_traceback = sys.exc_info()
trace_back = traceback.format_list(traceback.extract_tb(ex_traceback)[-1:])[-1]

logger.warning(f"ErrorMessage: {exc!s}")
logger.warning(f"Exception Type {ex_type.__name__}: ")

if not logger_trace_info:
logger.warning(f"Stack trace: {trace_back}")
else:
logger.exception(f"Stack trace: {trace_back}")


async def token_invalid_handler(request: Request, exc: TokenInvalidError) -> JSONResponse:
log_exception(exc, False)
response_content = errors.ERR_10002.dict()
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=response_content)


async def invalid_token_for_refresh_handler(request: Request, exc: TokenInvalidForRefreshError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=errors.ERR_10004.dict())


async def token_expired_handler(request: Request, exc: TokenExpireError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=errors.ERR_10003.dict())


async def permission_deny_handler(request: Request, exc: PermissionDenyError) -> JSONResponse:
log_exception(exc, False)
return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=errors.ERR_10004.dict())


async def resource_not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
log_exception(exc, True)
error_message = _(errors.ERR_404.message, name=exc.name, filed=exc.field, value=exc.value)
content = {"error": errors.ERR_404.error, "message": error_message}
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content)


async def resource_exist_handler(request: Request, exc: ExistError) -> JSONResponse:
log_exception(exc, True)
error_message = _(errors.ERR_409.message, name=exc.name, filed=exc.field, value=exc.value)
content = {"error": errors.ERR_409.error, "message": error_message}
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content)


def gener_error_handler(request: Request, exc: GenerError) -> JSONResponse:
log_exception(exc, True)
return JSONResponse(
status_code=exc.status_code,
content={
"error": exc.error.error,
"message": _(exc.error.message, **exc.params) if exc.params else _(exc.error.message),
},
)


def default_exception_handler(request: Request, exc: Exception) -> JSONResponse:
log_exception(exc, logger_trace_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": errors.ERR_500.error, "message": _(errors.ERR_500.message, request_id=request_id_ctx.get())},
)


exception_handlers = [
{"exception": TokenInvalidError, "handler": token_invalid_handler},
{"exception": TokenExpireError, "handler": token_expired_handler},
{"exception": TokenInvalidForRefreshError, "handler": invalid_token_for_refresh_handler},
{"exception": PermissionDenyError, "handler": permission_deny_handler},
{"exception": NotFoundError, "handler": resource_not_found_handler},
{"exception": ExistError, "handler": resource_exist_handler},
{"exception": GenerError, "handler": gener_error_handler},
]


sentry_ignore_errors = [
TokenExpireError,
TokenInvalidError,
Expand Down

0 comments on commit 97b16bc

Please sign in to comment.