chore(ml): update pydantic (immich-app#13230)
* update pydantic

* fix typing

* remove unused import

* remove unused schema
mertalev authored Oct 13, 2024
1 parent f29fb16 commit e7397f3
Showing 6 changed files with 186 additions and 82 deletions.
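
The substance of the change is the pydantic v1 → v2 migration: BaseSettings moved to the separate pydantic-settings package, the inner `class Config` gave way to a `model_config = SettingsConfigDict(...)` attribute, and optional fields now need an explicit default. A minimal sketch of the before/after pattern (the `host` field is illustrative, not from this repo):

    # pydantic v1 (before):
    #
    #     from pydantic import BaseSettings
    #
    #     class Settings(BaseSettings):
    #         host: str | None  # implicitly defaults to None in v1
    #
    #         class Config:
    #             env_prefix = "APP_"

    # pydantic v2 (after):
    from pydantic_settings import BaseSettings, SettingsConfigDict


    class Settings(BaseSettings):
        model_config = SettingsConfigDict(env_prefix="APP_")

        host: str | None = None  # v2 requires the explicit default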
machine-learning/app/config.py (24 changes: 13 additions & 11 deletions)

@@ -6,19 +6,27 @@
 from socket import socket

 from gunicorn.arbiter import Arbiter
-from pydantic import BaseModel, BaseSettings
+from pydantic import BaseModel
+from pydantic_settings import BaseSettings, SettingsConfigDict
 from rich.console import Console
 from rich.logging import RichHandler
 from uvicorn import Server
 from uvicorn.workers import UvicornWorker


 class PreloadModelData(BaseModel):
-    clip: str | None
-    facial_recognition: str | None
+    clip: str | None = None
+    facial_recognition: str | None = None


 class Settings(BaseSettings):
+    model_config = SettingsConfigDict(
+        env_prefix="MACHINE_LEARNING_",
+        case_sensitive=False,
+        env_nested_delimiter="__",
+        protected_namespaces=("settings_",),
+    )
+
     cache_folder: Path = Path("/cache")
     model_ttl: int = 300
     model_ttl_poll_s: int = 10
@@ -34,23 +42,17 @@ class Settings(BaseSettings):
     ann_tuning_level: int = 2
     preload: PreloadModelData | None = None

-    class Config:
-        env_prefix = "MACHINE_LEARNING_"
-        case_sensitive = False
-        env_nested_delimiter = "__"
-
     @property
     def device_id(self) -> str:
         return os.environ.get("MACHINE_LEARNING_DEVICE_ID", "0")


 class LogSettings(BaseSettings):
+    model_config = SettingsConfigDict(case_sensitive=False)
+
     immich_log_level: str = "info"
     no_color: bool = False

-    class Config:
-        case_sensitive = False
-

 _clean_name = str.maketrans(":\\/", "___", ".")
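
Two details of the new SettingsConfigDict are worth noting. `protected_namespaces=("settings_",)` overrides pydantic v2's default protected namespace `("model_",)`, which would otherwise warn about fields like `model_ttl`. And `env_nested_delimiter="__"` populates nested models such as PreloadModelData from flat environment variables. A sketch of the mapping, assuming the Settings class above (the CLIP model name is just an example value):

    import os

    from app.config import Settings

    # env_prefix strips "MACHINE_LEARNING_"; env_nested_delimiter splits on "__",
    # so the first variable targets Settings.preload.clip.
    os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
    os.environ["MACHINE_LEARNING_MODEL_TTL"] = "600"

    settings = Settings()
    assert settings.model_ttl == 600
    assert settings.preload is not None
    assert settings.preload.clip == "ViT-B-32__openai"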
machine-learning/app/main.py (16 changes: 7 additions & 9 deletions)

@@ -12,7 +12,7 @@

 import orjson
 from fastapi import Depends, FastAPI, File, Form, HTTPException
-from fastapi.responses import ORJSONResponse
+from fastapi.responses import ORJSONResponse, PlainTextResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from PIL.Image import Image
 from pydantic import ValidationError
@@ -28,14 +28,12 @@
     InferenceEntries,
     InferenceEntry,
     InferenceResponse,
-    MessageResponse,
     ModelFormat,
     ModelIdentity,
     ModelTask,
     ModelType,
     PipelineRequest,
     T,
-    TextResponse,
 )

 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
@@ -127,14 +125,14 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
 app = FastAPI(lifespan=lifespan)


-@app.get("/", response_model=MessageResponse)
-async def root() -> dict[str, str]:
-    return {"message": "Immich ML"}
+@app.get("/")
+async def root() -> ORJSONResponse:
+    return ORJSONResponse({"message": "Immich ML"})


-@app.get("/ping", response_model=TextResponse)
-def ping() -> str:
-    return "pong"
+@app.get("/ping")
+def ping() -> PlainTextResponse:
+    return PlainTextResponse("pong")


 @app.post("/predict", dependencies=[Depends(update_state)])
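
The dropped `response_model` wrappers trace back to pydantic v2 removing `__root__` custom root types; the v2 replacement is RootModel. Rather than porting TextResponse, the commit returns Starlette response classes directly. For comparison, a hedged sketch of what the port would have looked like:

    from fastapi import FastAPI
    from pydantic import RootModel

    app = FastAPI()

    # v2 spelling of the removed `TextResponse(__root__: str)` schema —
    # illustration only; the commit deletes the schema instead.
    TextResponse = RootModel[str]


    @app.get("/ping", response_model=TextResponse)
    def ping() -> str:
        return "pong"

Note the two are not byte-identical: the response_model route serializes to JSON ("pong", quoted), while PlainTextResponse sends the bare string as text/plain — which is what the new test asserts.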
machine-learning/app/schemas.py (12 changes: 2 additions & 10 deletions)

@@ -1,9 +1,9 @@
 from enum import Enum
-from typing import Any, Literal, Protocol, TypedDict, TypeGuard, TypeVar
+from typing import Any, Literal, Protocol, TypeGuard, TypeVar

 import numpy as np
 import numpy.typing as npt
 from pydantic import BaseModel
+from typing_extensions import TypedDict


 class StrEnum(str, Enum):
@@ -13,14 +13,6 @@ def __str__(self) -> str:
         return self.value


-class TextResponse(BaseModel):
-    __root__: str
-
-
-class MessageResponse(BaseModel):
-    message: str
-
-
 class BoundingBox(TypedDict):
     x1: int
     y1: int
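
The TypedDict import switch is also pydantic-driven: on Python < 3.12, pydantic v2 rejects `typing.TypedDict` inside validated types and asks for `typing_extensions.TypedDict`, which carries the runtime metadata it needs. A sketch of validating one with v2's TypeAdapter (illustrative usage; the x2/y2 fields are assumed, since the diff is truncated after y1):

    from pydantic import TypeAdapter
    from typing_extensions import TypedDict


    class BoundingBox(TypedDict):
        x1: int
        y1: int
        x2: int  # assumed; not visible in the truncated diff
        y2: int


    # TypeAdapter applies pydantic validation to a plain TypedDict.
    box = TypeAdapter(BoundingBox).validate_python({"x1": 0, "y1": 0, "x2": 128, "y2": 96})
    print(box)  # {'x1': 0, 'y1': 0, 'x2': 128, 'y2': 96}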
machine-learning/app/test_main.py (17 changes: 16 additions & 1 deletion)

@@ -810,11 +810,26 @@ async def test_falls_back_to_onnx_if_other_format_does_not_exist(
     mock_model.model_format = ModelFormat.ONNX


+def test_root_endpoint(deployed_app: TestClient) -> None:
+    response = deployed_app.get("http://localhost:3003")
+
+    body = response.json()
+    assert response.status_code == 200
+    assert body == {"message": "Immich ML"}
+
+
+def test_ping_endpoint(deployed_app: TestClient) -> None:
+    response = deployed_app.get("http://localhost:3003/ping")
+
+    assert response.status_code == 200
+    assert response.text == "pong"
+
+
 @pytest.mark.skipif(
     not settings.test_full,
     reason="More time-consuming since it deploys the app and loads models.",
 )
-class TestEndpoints:
+class TestPredictionEndpoints:
     def test_clip_image_endpoint(
         self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
     ) -> None:
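
The two new tests cover the rewritten root and ping handlers through the `deployed_app` fixture, defined elsewhere in the suite. A hypothetical minimal version, assuming it wraps the FastAPI app in Starlette's TestClient (per the skipif note, the real fixture deploys the full app and loads models):

    from typing import Iterator

    import pytest
    from fastapi.testclient import TestClient

    from app.main import app


    @pytest.fixture
    def deployed_app() -> Iterator[TestClient]:
        # Entering the context manager runs the app's lifespan hooks.
        with TestClient(app) as client:
            yield client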
(The remaining two changed files did not load and are not shown.)
