Skip to content

Commit

Permalink
feat: add file paths to data and prediction metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-encord committed Dec 5, 2023
1 parent 5fd48ad commit 30a5ec2
Showing 1 changed file with 52 additions and 7 deletions.
59 changes: 52 additions & 7 deletions src/encord_active/public/active_project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Union
from typing import Optional, Type, Union
from uuid import UUID, uuid4

import pandas as pd
Expand All @@ -11,13 +12,16 @@

from encord_active.db.models import (
Project,
ProjectAnnotationAnalytics,
ProjectDataAnalytics,
ProjectDataUnitMetadata,
ProjectPrediction,
ProjectPredictionAnalytics,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from encord_active.lib.common.data_utils import url_to_file_path

_P = Project
_T = ProjectTag
Expand All @@ -29,15 +33,25 @@ class DataUnitItem:
frame: int


AnalyticsModel = Union[Type[ProjectDataAnalytics], Type[ProjectPredictionAnalytics], Type[ProjectAnnotationAnalytics]]


def get_active_engine(path_to_db: Union[str, Path]) -> Engine:
path = Path(path_to_db) if isinstance(path_to_db, str) else path_to_db
return get_engine(path, use_alembic=False)


class ActiveProject:
def __init__(self, engine: Engine, project_name: str):
@classmethod
def from_db_file(cls, db_file_path: Union[str, Path], project_name: str):
path = Path(db_file_path) if isinstance(db_file_path, str) else db_file_path
engine = get_active_engine(db_file_path)
return ActiveProject(engine, project_name, root_path=path.parent)

def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optional[Path] = None):
self._engine = engine
self._project_name = project_name
self._root_path = root_path

with Session(self._engine) as sess:
res = sess.exec(
Expand Down Expand Up @@ -65,7 +79,7 @@ def __init__(self, engine: Engine, project_name: str):
def get_ontology(self) -> OntologyStructure:
return self.ontology

def get_prediction_metrics(self) -> pd.DataFrame:
def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFrame:
if self._model_hash is None:
raise ValueError(f"Project with name {self._project_name} does not have any model predictions")

Expand All @@ -88,17 +102,48 @@ def get_prediction_metrics(self) -> pd.DataFrame:
(P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore
).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash)

transform = None
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, P)

df = pd.DataFrame(sess.exec(stmt).all())
df.columns = list(sess.exec(stmt).keys())

if transform is not None:
df = transform(df)

return df

def get_images_metrics(self) -> pd.DataFrame:
def _join_path_statement(self, stmt, base_model: AnalyticsModel):
stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video)
stmt = stmt.join(
ProjectDataUnitMetadata,
onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash)
& (base_model.frame == ProjectDataUnitMetadata.frame),
).where(ProjectDataUnitMetadata.project_hash == self.project_hash)

def transform(df):
if self._root_path is None:
raise ValueError(f"Root path is not set. Provide it in the constructor or use `from_db_file`")
df["data_uri"] = df.data_uri.map(partial(url_to_file_path, project_dir=self._root_path))
return df

return stmt, transform

def get_images_metrics(self, *, include_data_uris: bool = False) -> pd.DataFrame:
with Session(self._engine) as sess:
image_metrics = sess.exec(
select(ProjectDataAnalytics).where(ProjectDataAnalytics.project_hash == self.project_hash)
stmt = select(*[c for c in ProjectDataAnalytics.__table__.c]).where( # type: ignore -- hack to get all columns without the pydantic model
ProjectDataAnalytics.project_hash == self.project_hash
)
df = pd.DataFrame([i.dict() for i in image_metrics])
transform = None
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics)
image_metrics = sess.exec(stmt).all()
df = pd.DataFrame(image_metrics)
df.columns = list(sess.execute(stmt).keys())

if transform is not None:
df = transform(df)

return df

Expand Down

0 comments on commit 30a5ec2

Please sign in to comment.