Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,28 +1,81 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from functools import lru_cache

from anomalib.models import list_models
from fastapi import APIRouter

from api.endpoints import API_PREFIX
from pydantic_models import TrainableModelList
from pydantic_models import ModelFamily, ModelInfo, TrainableModelList, TrainingTime

router = APIRouter(prefix=f"{API_PREFIX}/trainable-models", tags=["Trainable Models"])


@lru_cache
def _get_trainable_models() -> TrainableModelList: # pragma: no cover
"""Return list of trainable models with optional descriptions.

The available models are retrieved from ``anomalib.models.list_models``. Currently, only
the model names are returned. Descriptions can be added manually in the
``_MODEL_DESCRIPTIONS`` mapping below.
Currently hardcoded for v1.
"""
model_names = sorted(list_models(case="pascal"))
models = [
ModelInfo(
name="PatchCore",
class_name="patchcore",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.MEMORY_BANK, ModelFamily.PATCH_BASED],
recommended=True,
),
ModelInfo(
name="FRE",
class_name="fre",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.RECONSTRUCTION_BASED],
recommended=True,
),
ModelInfo(
name="Dinomaly",
class_name="dinomaly",
training_time=TrainingTime.CYCLE,
model_family=[ModelFamily.STUDENT_TEACHER, ModelFamily.RECONSTRUCTION_BASED],
recommended=True,
),
ModelInfo(
name="CFA",
class_name="cfa",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.MEMORY_BANK],
),
ModelInfo(
name="DFM",
class_name="dfm",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.MEMORY_BANK],
),
ModelInfo(
name="FastFlow",
class_name="fastflow",
training_time=TrainingTime.CYCLE,
model_family=[ModelFamily.DISTRIBUTION_MAP],
),
ModelInfo(
name="Padim",
class_name="padim",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.MEMORY_BANK, ModelFamily.PATCH_BASED, ModelFamily.DISTRIBUTION_MAP],
),
ModelInfo(
name="Reverse Distillation",
class_name="reverse_distillation",
training_time=TrainingTime.CYCLE,
model_family=[ModelFamily.STUDENT_TEACHER],
),
ModelInfo(
name="SuperSimpleNet",
class_name="supersimplenet",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.RECONSTRUCTION_BASED],
),
]

return TrainableModelList(trainable_models=model_names)
return TrainableModelList(trainable_models=models)


@router.get("", summary="List trainable models")
Expand Down
5 changes: 4 additions & 1 deletion application/backend/src/pydantic_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .project import Project, ProjectList, ProjectUpdate
from .sink import DisconnectedSinkConfig, OutputFormat, Sink, SinkType
from .source import DisconnectedSourceConfig, Source, SourceType
from .trainable_model import TrainableModelList
from .trainable_model import ModelFamily, ModelInfo, TrainableModelList, TrainingTime

__all__ = [
"DatasetSnapshot",
Expand All @@ -26,6 +26,8 @@
"Media",
"MediaList",
"Model",
"ModelFamily",
"ModelInfo",
"ModelList",
"OutputFormat",
"Pipeline",
Expand All @@ -42,4 +44,5 @@
"SourceType",
"TimeWindow",
"TrainableModelList",
"TrainingTime",
]
27 changes: 26 additions & 1 deletion application/backend/src/pydantic_models/trainable_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum

from pydantic import BaseModel


class ModelFamily(StrEnum):
PATCH_BASED = "patch_based"
MEMORY_BANK = "memory_bank"
STUDENT_TEACHER = "student_teacher"
RECONSTRUCTION_BASED = "reconstruction_based"
DISTRIBUTION_MAP = "distribution_map"


class TrainingTime(StrEnum):
COFFEE = "coffee"
WALK = "walk"
CYCLE = "cycle"


class ModelInfo(BaseModel):
name: str
class_name: str
training_time: TrainingTime
model_family: list[ModelFamily]
recommended: bool = False
license: str = "Apache-2.0"


class TrainableModelList(BaseModel):
"""List wrapper for returning multiple trainable models."""

trainable_models: list[str]
trainable_models: list[ModelInfo]
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,37 @@

from fastapi import status

from api.endpoints.trainable_models_endpoints import _get_trainable_models # noqa: PLC2701
from pydantic_models import ModelFamily, ModelInfo, TrainableModelList, TrainingTime


def test_list_trainable_models(fxt_client):
_get_trainable_models.cache_clear()
# Mock anomalib.models.list_models to return a predictable set
with patch("api.endpoints.trainable_models_endpoints.list_models", return_value={"padim", "patchcore"}):
mock_response = {
"trainable_models": [
{
"name": "Padim",
"class_name": "padim",
"training_time": "coffee",
"recommended": False,
"license": "Apache-2.0",
"model_family": ["memory_bank", "patch_based", "distribution_map"],
},
]
}
with patch(
"api.endpoints.trainable_models_endpoints._get_trainable_models",
return_value=TrainableModelList(
trainable_models=[
ModelInfo(
name="Padim",
class_name="padim",
training_time=TrainingTime.COFFEE,
model_family=[ModelFamily.MEMORY_BANK, ModelFamily.PATCH_BASED, ModelFamily.DISTRIBUTION_MAP],
)
]
),
):
response = fxt_client.get("/api/trainable-models")

assert response.status_code == status.HTTP_200_OK
body = response.json()
assert body == {"trainable_models": ["padim", "patchcore"]} or body == {"trainable_models": ["patchcore", "padim"]}
assert response.json() == mock_response
61 changes: 61 additions & 0 deletions application/ui/src/assets/icons/coffee.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 41 additions & 0 deletions application/ui/src/assets/icons/cycle.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 11 additions & 8 deletions application/ui/src/assets/icons/index.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
export { ReactComponent as ActiveIcon } from './active-icon.svg';
export { ReactComponent as BuildIcon } from './build-icon.svg';
export { ReactComponent as CameraOff } from './camera-off.svg';
export { ReactComponent as Camera } from './camera.svg';
export { ReactComponent as Coffee } from './coffee.svg';
export { ReactComponent as Cycle } from './cycle.svg';
export { ReactComponent as Dataset } from './dataset.svg';
export { ReactComponent as DoubleChevronRightIcon } from './double-chevron-right-icon.svg';
export { ReactComponent as ErrorIcon } from './error-icon.svg';
export { ReactComponent as Fireworks } from './fire-works.svg';
export { ReactComponent as FolderArrowRight } from './folder-arrow-right.svg';
export { ReactComponent as Folder } from './folder.svg';
export { ReactComponent as Genicam } from './genicam.svg';
export { ReactComponent as Image } from './image.svg';
export { ReactComponent as ImagesFolder } from './images-folder.svg';
export { ReactComponent as IpCamera } from './ip-camera.svg';
export { ReactComponent as LinkExpired } from './link-expired.svg';
export { ReactComponent as LiveFeedIcon } from './live-feed-icon.svg';
export { ReactComponent as LoadingIcon } from './loading-icon.svg';
export { ReactComponent as Models } from './models.svg';
export { ReactComponent as Mqtt } from './mqtt.svg';
export { ReactComponent as NotFoundIcon } from './not-found.svg';
export { ReactComponent as Onnx } from './onnx.svg';
export { ReactComponent as OpenVino } from './openvino.svg';
export { ReactComponent as PipelineLink } from './pipeline-link.svg';
export { ReactComponent as PyTorch } from './pytorch.svg';
export { ReactComponent as Ros } from './ros.svg';
export { ReactComponent as Stats } from './stats.svg';
export { ReactComponent as SuccessIcon } from './success-icon.svg';
export { ReactComponent as ThreeDotsCircle } from './three-dots-circle.svg';
export { ReactComponent as VideoFile } from './video-file.svg';
export { ReactComponent as Walk } from './walk.svg';
export { ReactComponent as Webcam } from './webcam.svg';
export { ReactComponent as Webhook } from './webhook.svg';
export { ReactComponent as Image } from './image.svg';
export { ReactComponent as Fireworks } from './fire-works.svg';
export { ReactComponent as PipelineLink } from './pipeline-link.svg';
export { ReactComponent as Folder } from './folder.svg';
export { ReactComponent as LinkExpired } from './link-expired.svg';
export { ReactComponent as ActiveIcon } from './active-icon.svg';
export { ReactComponent as LoadingIcon } from './loading-icon.svg';
export { ReactComponent as NotFoundIcon } from './not-found.svg';
41 changes: 41 additions & 0 deletions application/ui/src/assets/icons/walk.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading