Skip to content

Commit

Permalink
Pydantic v2 & FastAPI Annotated (BC-SECURITY#727)
Browse files Browse the repository at this point in the history
* bump deps

* convert but still have some failing tests

* add validators to get v1 functionality

* remove unnecessary changes

* some more fixes

* remove a unnecessary change

* Initial conversion to use Annotated

* Changelog
  • Loading branch information
vinnybod authored Nov 8, 2023
1 parent 957617e commit 3ed4c67
Show file tree
Hide file tree
Showing 43 changed files with 522 additions and 377 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Upgrade Pydantic to v2 (@Vinnybod)
- Update common FastAPI Dependencies to use 'Annotated' types for simpler code (@Vinnybod)

## [5.8.0] - 2023-11-06

- Warning: You may run into errors installing things such as nim if you are running the install script on a machine that previously ran it. This is due to permissions changes with the install script. In this case it is recommended to use a fresh machine or manually remove the offending directories/files.
Expand Down
19 changes: 15 additions & 4 deletions empire/server/api/jwt_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from typing import Annotated

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
Expand All @@ -8,7 +9,7 @@
from sqlalchemy.orm import Session
from starlette import status

from empire.server.api.v2.shared_dependencies import get_db
from empire.server.api.v2.shared_dependencies import CurrentSession
from empire.server.core.db import models
from empire.server.core.db.base import SessionLocal

Expand Down Expand Up @@ -69,7 +70,8 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):


async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
db: CurrentSession,
token: str = Depends(oauth2_scheme),
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -90,19 +92,28 @@ async def get_current_user(
return user


CurrentUser = Annotated[models.User, Depends(get_current_user)]


async def get_current_active_user(
current_user: models.User = Depends(get_current_user),
current_user: CurrentUser,
):
if not current_user.enabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user


CurrentActiveUser = Annotated[models.User, Depends(get_current_active_user)]


async def get_current_active_admin_user(
current_user: models.User = Depends(get_current_user),
current_user: CurrentUser,
):
if not current_user.enabled:
raise HTTPException(status_code=400, detail="Inactive user")
if not current_user.admin:
raise HTTPException(status_code=403, detail="Not an admin user")
return current_user


CurrentActiveAdminUser = Annotated[models.User, Depends(get_current_active_admin_user)]
15 changes: 7 additions & 8 deletions empire/server/api/v2/agent/agent_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from datetime import datetime

from fastapi import Depends, HTTPException, Query
from sqlalchemy.orm import Session

from empire.server.api.api_router import APIRouter
from empire.server.api.jwt_auth import get_current_active_user
Expand All @@ -17,7 +16,7 @@
domain_to_dto_agent_checkin,
domain_to_dto_agent_checkin_agg,
)
from empire.server.api.v2.shared_dependencies import get_db
from empire.server.api.v2.shared_dependencies import CurrentSession
from empire.server.api.v2.shared_dto import (
BadRequestResponse,
NotFoundResponse,
Expand All @@ -41,7 +40,7 @@
)


async def get_agent(uid: str, db: Session = Depends(get_db)):
async def get_agent(uid: str, db: CurrentSession):
agent = agent_service.get_by_id(db, uid)

if agent:
Expand All @@ -55,7 +54,7 @@ async def get_agent(uid: str, db: Session = Depends(get_db)):

@router.get("/checkins", response_model=AgentCheckIns)
def read_agent_checkins_all(
db: Session = Depends(get_db),
db: CurrentSession,
agents: list[str] = Query(None),
limit: int = 1000,
page: int = 1,
Expand All @@ -79,7 +78,7 @@ def read_agent_checkins_all(

@router.get("/checkins/aggregate", response_model=AgentCheckInsAggregate)
def read_agent_checkins_aggregate(
db: Session = Depends(get_db),
db: CurrentSession,
agents: list[str] = Query(None),
start_date: datetime | None = None,
end_date: datetime | None = None,
Expand Down Expand Up @@ -111,7 +110,7 @@ async def read_agent(uid: str, db_agent: models.Agent = Depends(get_agent)):

@router.get("/", response_model=Agents)
async def read_agents(
db: Session = Depends(get_db),
db: CurrentSession,
include_archived: bool = False,
include_stale: bool = True,
):
Expand All @@ -129,7 +128,7 @@ async def read_agents(
async def update_agent(
uid: str,
agent_req: AgentUpdateRequest,
db: Session = Depends(get_db),
db: CurrentSession,
db_agent: models.Agent = Depends(get_agent),
):
resp, err = agent_service.update_agent(db, db_agent, agent_req)
Expand All @@ -142,7 +141,7 @@ async def update_agent(

@router.get("/{uid}/checkins", response_model=AgentCheckIns)
def read_agent_checkins(
db: Session = Depends(get_db),
db: CurrentSession,
db_agent: models.Agent = Depends(get_agent),
limit: int = -1,
page: int = 1,
Expand Down
48 changes: 24 additions & 24 deletions empire/server/api/v2/agent/agent_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,35 +83,35 @@ class Agent(BaseModel):
name: str
# listener_id: int
listener: str
host_id: int | None
hostname: str | None
language: str | None
language_version: str | None
host_id: int | None = None
hostname: str | None = None
language: str | None = None
language_version: str | None = None
delay: int
jitter: float
external_ip: str | None
internal_ip: str | None
username: str | None
high_integrity: bool | None
process_id: int | None
process_name: str | None
os_details: str | None
external_ip: str | None = None
internal_ip: str | None = None
username: str | None = None
high_integrity: bool | None = None
process_id: int | None = None
process_name: str | None = None
os_details: str | None = None
nonce: str
checkin_time: datetime
lastseen_time: datetime
parent: str | None
children: str | None
servers: str | None
profile: str | None
functions: str | None
kill_date: str | None
working_hours: str | None
parent: str | None = None
children: str | None = None
servers: str | None = None
profile: str | None = None
functions: str | None = None
kill_date: str | None = None
working_hours: str | None = None
lost_limit: int
notes: str | None
architecture: str | None
notes: str | None = None
architecture: str | None = None
archived: bool
stale: bool
proxies: dict | None
proxies: dict | None = None
tags: list[Tag]


Expand Down Expand Up @@ -139,8 +139,8 @@ class AgentCheckInAggregate(BaseModel):

class AgentCheckInsAggregate(BaseModel):
records: list[AgentCheckInAggregate]
start_date: datetime | None
end_date: datetime | None
start_date: datetime | None = None
end_date: datetime | None = None
bucket_size: str


Expand All @@ -153,4 +153,4 @@ class AggregateBucket(str, Enum):

class AgentUpdateRequest(BaseModel):
name: str
notes: str | None
notes: str | None = None
15 changes: 6 additions & 9 deletions empire/server/api/v2/agent/agent_file_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from fastapi import Depends, HTTPException
from sqlalchemy.orm import Session

from empire.server.api.api_router import APIRouter
from empire.server.api.jwt_auth import get_current_active_user
from empire.server.api.v2.agent.agent_file_dto import AgentFile, domain_to_dto_file
from empire.server.api.v2.shared_dependencies import get_db
from empire.server.api.v2.shared_dependencies import CurrentSession
from empire.server.api.v2.shared_dto import BadRequestResponse, NotFoundResponse
from empire.server.core.agent_file_service import AgentFileService
from empire.server.core.agent_service import AgentService
Expand All @@ -25,7 +24,7 @@
)


async def get_agent(agent_id: str, db: Session = Depends(get_db)):
async def get_agent(agent_id: str, db: CurrentSession):
agent = agent_service.get_by_id(db, agent_id)

if agent:
Expand All @@ -35,7 +34,7 @@ async def get_agent(agent_id: str, db: Session = Depends(get_db)):


async def get_file(
uid: int, db: Session = Depends(get_db), db_agent: models.Agent = Depends(get_agent)
uid: int, db: CurrentSession, db_agent: models.Agent = Depends(get_agent)
):
file = agent_file_service.get_file(db, db_agent.session_id, uid)

Expand All @@ -47,9 +46,9 @@ async def get_file(
)


@router.get("/root", dependencies=[Depends(get_current_active_user)])
@router.get("/root")
async def read_file_root(
db: Session = Depends(get_db), db_agent: models.Agent = Depends(get_agent)
db: CurrentSession, db_agent: models.Agent = Depends(get_agent)
):
file = agent_file_service.get_file_by_path(db, db_agent.session_id, "/")

Expand All @@ -61,9 +60,7 @@ async def read_file_root(
)


@router.get(
"/{uid}", response_model=AgentFile, dependencies=[Depends(get_current_active_user)]
)
@router.get("/{uid}", response_model=AgentFile)
async def read_file(
uid: int,
db_agent: models.Agent = Depends(get_agent),
Expand Down
8 changes: 3 additions & 5 deletions empire/server/api/v2/agent/agent_file_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://pydantic-docs.helpmanual.io/usage/postponed_annotations/#self-referencing-models
from __future__ import annotations

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from empire.server.api.v2.shared_dto import (
DownloadDescription,
Expand Down Expand Up @@ -32,12 +32,10 @@ class AgentFile(BaseModel):
name: str
path: str
is_file: bool
parent_id: int | None
parent_id: int | None = None
downloads: list[DownloadDescription]
children: list[AgentFile] = []

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)


AgentFile.update_forward_refs()
Loading

0 comments on commit 3ed4c67

Please sign in to comment.