Skip to content

Commit

Permalink
feat: preload textual model
Browse files Browse the repository at this point in the history
  • Loading branch information
martabal committed Sep 16, 2024
1 parent 4735db8 commit 708a53a
Show file tree
Hide file tree
Showing 17 changed files with 301 additions and 19 deletions.
28 changes: 27 additions & 1 deletion machine-learning/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zipfile import BadZipFile

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
Expand All @@ -28,6 +28,7 @@
InferenceEntries,
InferenceEntry,
InferenceResponse,
LoadModelEntry,
MessageResponse,
ModelFormat,
ModelIdentity,
Expand Down Expand Up @@ -124,6 +125,24 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
raise HTTPException(422, "Invalid request format.")


def get_entry(entries: str = Form()) -> LoadModelEntry:
    """Parse the request form payload into a single model-load entry.

    Only the first (task, type) pair found in the payload is used; the model
    name is required, while options and ttl fall back to defaults.

    Raises:
        HTTPException: 422 if the payload is malformed or contains no entries.
    """
    try:
        request: PipelineRequest = orjson.loads(entries)
        for task, types in request.items():
            for model_type, entry in types.items():
                parsed: LoadModelEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": model_type,
                    "options": entry.get("options", {}),
                    # fall back to the server-wide default when the client omits ttl
                    "ttl": entry.get("ttl", settings.ttl),
                }
                return parsed
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.") from e
    # Previously fell through and returned None; an empty request is invalid.
    raise HTTPException(422, "Invalid request format.")


app = FastAPI(lifespan=lifespan)


Expand All @@ -137,6 +156,13 @@ def ping() -> str:
return "pong"


@app.post("/load", response_model=TextResponse)
async def load_model(entry: LoadModelEntry = Depends(get_entry)) -> Response:
    """Preload a model into the cache so the first prediction is not delayed.

    Uses the ttl parsed from the request (see get_entry) rather than the
    global default, so callers control how long the preloaded model stays
    cached. Previously ``settings.model_ttl`` was used, silently discarding
    the per-request ttl.
    """
    model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=entry["ttl"])
    # Return value discarded: loading mutates the cached model in place.
    await load(model)
    return Response(status_code=200)


@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
entries: InferenceEntries = Depends(get_entries),
Expand Down
11 changes: 11 additions & 0 deletions machine-learning/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ class InferenceEntry(TypedDict):
options: dict[str, Any]


class LoadModelEntry(InferenceEntry):
    """An InferenceEntry plus a cache lifetime for explicit model preloading."""

    # Seconds to keep the loaded model in the cache; expected to be a
    # positive integer (callers default it from settings when absent).
    #
    # NOTE(review): TypedDict subclasses are plain dicts at runtime — the
    # previous __init__ (with its ttl > 0 check) could never run, because
    # instances are built as dict literals, not via the constructor. Any
    # runtime validation belongs at the request-parsing boundary instead.
    ttl: int


InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]


Expand Down
1 change: 1 addition & 0 deletions mobile/openapi/README.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mobile/openapi/lib/api.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions mobile/openapi/lib/api_client.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion mobile/openapi/lib/model/clip_config.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

107 changes: 107 additions & 0 deletions mobile/openapi/lib/model/load_textual_model_on_connection.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions open-api/immich-openapi-specs.json
Original file line number Diff line number Diff line change
Expand Up @@ -8603,12 +8603,16 @@
"enabled": {
"type": "boolean"
},
"loadTextualModelOnConnection": {
"$ref": "#/components/schemas/LoadTextualModelOnConnection"
},
"modelName": {
"type": "string"
}
},
"required": [
"enabled",
"loadTextualModelOnConnection",
"modelName"
],
"type": "object"
Expand Down Expand Up @@ -9433,6 +9437,23 @@
],
"type": "object"
},
"LoadTextualModelOnConnection": {
"properties": {
"enabled": {
"type": "boolean"
},
"ttl": {
"format": "int64",
"minimum": 0,
"type": "number"
}
},
"required": [
"enabled",
"ttl"
],
"type": "object"
},
"LogLevel": {
"enum": [
"verbose",
Expand Down
5 changes: 5 additions & 0 deletions open-api/typescript-sdk/src/fetch-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1100,8 +1100,13 @@ export type SystemConfigLoggingDto = {
enabled: boolean;
level: LogLevel;
};
// Generated client type mirroring the server's LoadTextualModelOnConnection
// schema (immich-openapi-specs.json); do not hand-edit semantics here.
export type LoadTextualModelOnConnection = {
  enabled: boolean;
  // seconds to keep the preloaded textual model in the cache
  ttl: number;
};
// Generated client type for the CLIP machine-learning configuration.
export type ClipConfig = {
  enabled: boolean;
  // controls preloading the textual model when a client connects
  loadTextualModelOnConnection: LoadTextualModelOnConnection;
  modelName: string;
};
export type DuplicateDetectionConfig = {
Expand Down
8 changes: 8 additions & 0 deletions server/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ export interface SystemConfig {
clip: {
enabled: boolean;
modelName: string;
loadTextualModelOnConnection: {
enabled: boolean;
ttl: number;
};
};
duplicateDetection: {
enabled: boolean;
Expand Down Expand Up @@ -270,6 +274,10 @@ export const defaults = Object.freeze<SystemConfig>({
clip: {
enabled: true,
modelName: 'ViT-B-32__openai',
loadTextualModelOnConnection: {
enabled: false,
ttl: 300,
},
},
duplicateDetection: {
enabled: true,
Expand Down
17 changes: 15 additions & 2 deletions server/src/dtos/model-config.dto.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
import { ValidateBoolean } from 'src/validation';

export class TaskConfig {
Expand All @@ -14,7 +14,20 @@ export class ModelConfig extends TaskConfig {
modelName!: string;
}

export class CLIPConfig extends ModelConfig {}
// Config for preloading the CLIP textual model when a client connects.
export class LoadTextualModelOnConnection extends TaskConfig {
  // Cache lifetime (seconds) for the preloaded model.
  // NOTE(review): @Min(0) permits ttl = 0, but the ML service's
  // LoadModelEntry rejects ttl <= 0 — confirm whether 0 is meant to be valid.
  @IsNumber()
  @Min(0)
  @Type(() => Number)
  @ApiProperty({ type: 'number', format: 'int64' })
  ttl!: number;
}

// CLIP model configuration; extends ModelConfig (enabled + modelName).
export class CLIPConfig extends ModelConfig {
  // Nested config controlling textual-model preload behavior;
  // @Type + @ValidateNested ensure the nested object is transformed and validated.
  @Type(() => LoadTextualModelOnConnection)
  @ValidateNested()
  @IsObject()
  loadTextualModelOnConnection!: LoadTextualModelOnConnection;
}

export class DuplicateDetectionConfig extends TaskConfig {
@IsNumber()
Expand Down
7 changes: 6 additions & 1 deletion server/src/interfaces/machine-learning.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ export type ModelPayload = { imagePath: string } | { text: string };

type ModelOptions = { modelName: string };

// Options for explicitly (pre)loading a model, adding a cache lifetime.
export interface LoadModelOptions extends ModelOptions {
  // seconds the loaded model should remain cached
  ttl: number;
}

export type FaceDetectionOptions = ModelOptions & { minScore: number };

type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;

export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };

export type FacialRecognitionRequest = {
Expand All @@ -54,4 +58,5 @@ export interface IMachineLearningRepository {
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
loadTextModel(url: string, config: ModelOptions): Promise<void>;
}
Loading

0 comments on commit 708a53a

Please sign in to comment.