feat: preload textual model #12729

Draft · wants to merge 3 commits into main
Changes from 1 commit
28 changes: 27 additions & 1 deletion machine-learning/app/main.py
@@ -11,7 +11,7 @@
from zipfile import BadZipFile

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
@@ -28,6 +28,7 @@
    InferenceEntries,
    InferenceEntry,
    InferenceResponse,
    LoadModelEntry,
    MessageResponse,
    ModelFormat,
    ModelIdentity,
@@ -124,6 +125,24 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
        raise HTTPException(422, "Invalid request format.")


def get_entry(entries: str = Form()) -> LoadModelEntry:
    try:
        request: PipelineRequest = orjson.loads(entries)
        for task, types in request.items():
            for type, entry in types.items():
                parsed: LoadModelEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": type,
                    "options": entry.get("options", {}),
                    "ttl": entry["ttl"] if "ttl" in entry else settings.ttl,
                }
                return parsed
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.")


app = FastAPI(lifespan=lifespan)


@@ -137,6 +156,13 @@ def ping() -> str:
    return "pong"


@app.post("/load", response_model=TextResponse)
async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
model = await load(model)
return Response(status_code=200)


@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
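For reference, a hypothetical caller of the new /load endpoint. The form field name ("entries"), the "clip"/"textual" task and type keys, and the ML service port (3003) are assumptions inferred from get_entry above, not confirmed by this PR.

```python
# Sketch only: ask the ML service to preload the CLIP textual model.
# Endpoint shape and JSON keys are inferred from get_entry(); the port is assumed.
import requests

payload = '{"clip": {"textual": {"modelName": "ViT-B-32__openai", "ttl": 300}}}'
resp = requests.post("http://localhost:3003/load", data={"entries": payload})
resp.raise_for_status()  # expect 200 once the model is in the cache
```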
11 changes: 11 additions & 0 deletions machine-learning/app/schemas.py
@@ -109,6 +109,17 @@ class InferenceEntry(TypedDict):
    options: dict[str, Any]


class LoadModelEntry(InferenceEntry):
    ttl: int

    def __init__(self, name: str, task: ModelTask, type: ModelType, options: dict[str, Any], ttl: int):
        super().__init__(name=name, task=task, type=type, options=options)

        if ttl <= 0:
            raise ValueError("ttl must be a positive integer")
        self.ttl = ttl
mertalev (Contributor) commented on Sep 16, 2024:

InferenceEntry is just a type for a dict, so this __init__ shouldn't exist.

Also, I have some reservations about the ttl field here. TTL is currently set by an env and assumed to be the same for all models. At minimum, you would need to change the idle shutdown conditions to not rely on the env anymore and instead (synchronously) check if the model cache is empty.

Moreover, I'm not sure if this is something that the caller should be able to decide. The relationship between server and ML isn't necessarily one-to-one, so the settings should be designed with that in mind. If you have multiple servers that share a machine learning service, the effective TTL now depends on the order of requests. TTL is also relevant to whoever is deploying ML and may not be something they want a caller to be able to configure.
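A minimal sketch of that suggestion: keep LoadModelEntry a plain TypedDict and validate ttl where the form payload is parsed. InferenceEntry's fields are reproduced here (with task/type simplified to str) so the snippet stands alone; parse_load_entry is a hypothetical helper, not code from this PR.

```python
# Sketch, not the PR's code: TypedDicts have no runtime behavior, so the
# ttl check moves into the parsing helper instead of an __init__.
from typing import Any, TypedDict


class InferenceEntry(TypedDict):
    name: str
    task: str  # ModelTask in the real schemas.py
    type: str  # ModelType in the real schemas.py
    options: dict[str, Any]


class LoadModelEntry(InferenceEntry):
    ttl: int


def parse_load_entry(raw: dict[str, Any], task: str, type: str, default_ttl: int) -> LoadModelEntry:
    ttl = int(raw.get("ttl", default_ttl))
    if ttl <= 0:
        raise ValueError("ttl must be a positive integer")
    return {
        "name": raw["modelName"],
        "task": task,
        "type": type,
        "options": raw.get("options", {}),
        "ttl": ttl,
    }
```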



InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]


1 change: 1 addition & 0 deletions mobile/openapi/README.md
Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions mobile/openapi/lib/api.dart
Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions mobile/openapi/lib/api_client.dart
Some generated files are not rendered by default.

10 changes: 9 additions & 1 deletion mobile/openapi/lib/model/clip_config.dart
Some generated files are not rendered by default.

107 changes: 107 additions & 0 deletions mobile/openapi/lib/model/load_textual_model_on_connection.dart
Some generated files are not rendered by default.

21 changes: 21 additions & 0 deletions open-api/immich-openapi-specs.json
@@ -8603,12 +8603,16 @@
"enabled": {
"type": "boolean"
},
"loadTextualModelOnConnection": {
"$ref": "#/components/schemas/LoadTextualModelOnConnection"
},
"modelName": {
"type": "string"
}
},
"required": [
"enabled",
"loadTextualModelOnConnection",
"modelName"
],
"type": "object"
@@ -9433,6 +9437,23 @@
        ],
        "type": "object"
      },
      "LoadTextualModelOnConnection": {
        "properties": {
          "enabled": {
            "type": "boolean"
          },
          "ttl": {
            "format": "int64",
            "minimum": 0,
            "type": "number"
          }
        },
        "required": [
          "enabled",
          "ttl"
        ],
        "type": "object"
      },
      "LogLevel": {
        "enum": [
          "verbose",
5 changes: 5 additions & 0 deletions open-api/typescript-sdk/src/fetch-client.ts
@@ -1100,8 +1100,13 @@ export type SystemConfigLoggingDto = {
    enabled: boolean;
    level: LogLevel;
};
export type LoadTextualModelOnConnection = {
    enabled: boolean;
    ttl: number;
};
export type ClipConfig = {
    enabled: boolean;
    loadTextualModelOnConnection: LoadTextualModelOnConnection;
    modelName: string;
};
export type DuplicateDetectionConfig = {
8 changes: 8 additions & 0 deletions server/src/config.ts
@@ -120,6 +120,10 @@
    clip: {
      enabled: boolean;
      modelName: string;
      loadTextualModelOnConnection: {
        enabled: boolean;
        ttl: number;
      };
    };
    duplicateDetection: {
      enabled: boolean;
@@ -270,6 +274,10 @@ export const defaults = Object.freeze<SystemConfig>({
    clip: {
      enabled: true,
      modelName: 'ViT-B-32__openai',
      loadTextualModelOnConnection: {
        enabled: false,
        ttl: 300,
      },
    },
    duplicateDetection: {
      enabled: true,
17 changes: 15 additions & 2 deletions server/src/dtos/model-config.dto.ts
@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
import { ValidateBoolean } from 'src/validation';

export class TaskConfig {
@@ -14,7 +14,20 @@ export class ModelConfig extends TaskConfig {
  modelName!: string;
}

export class CLIPConfig extends ModelConfig {}
export class LoadTextualModelOnConnection extends TaskConfig {
  @IsNumber()
  @Min(0)
  @Type(() => Number)
  @ApiProperty({ type: 'number', format: 'int64' })
  ttl!: number;
}

export class CLIPConfig extends ModelConfig {
  @Type(() => LoadTextualModelOnConnection)
  @ValidateNested()
  @IsObject()
  loadTextualModelOnConnection!: LoadTextualModelOnConnection;
Contributor commented:

Can this be less wordy? Maybe preloadTextualModel?

Also, I imagine there are a lot of sessions where nothing gets searched and the model is loaded unnecessarily. If this were instead done when the user clicks the search bar, it'd almost always be a positive and could be enabled by default.

}

export class DuplicateDetectionConfig extends TaskConfig {
  @IsNumber()
7 changes: 6 additions & 1 deletion server/src/interfaces/machine-learning.interface.ts
@@ -24,13 +24,17 @@ export type ModelPayload = { imagePath: string } | { text: string };

type ModelOptions = { modelName: string };

export interface LoadModelOptions extends ModelOptions {
  ttl: number;
}

export type FaceDetectionOptions = ModelOptions & { minScore: number };

type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;

export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };

export type FacialRecognitionRequest = {
@@ -54,4 +58,5 @@ export interface IMachineLearningRepository {
  encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
  encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
  detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
  loadTextModel(url: string, config: ModelOptions): Promise<void>;
}