Skip to content

Commit

Permalink
feat: active project export data (#680)
Browse files Browse the repository at this point in the history
* refactor: clean up data querying

* feat: add file paths to data and prediction metrics

* feat: add support for including tags as well

* fix: linting

* fix: include also data without tags
  • Loading branch information
frederik-encord authored Jan 16, 2024
1 parent 86f9967 commit b4f46ff
Showing 1 changed file with 164 additions and 85 deletions.
249 changes: 164 additions & 85 deletions src/encord_active/public/active_project.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,201 @@
from dataclasses import dataclass
from typing import Any
from pathlib import Path
from typing import Optional, Type, Union
from uuid import UUID, uuid4

import pandas as pd
from encord.objects import OntologyStructure
from sqlalchemy import MetaData, Table, create_engine, select, text
from sqlalchemy import Integer, func
from sqlalchemy.engine import Engine
from sqlalchemy.sql import Select
from sqlmodel import Session
from sqlmodel import select as sqlmodel_select
from sqlmodel import Session, select

from encord_active.db.models import ProjectTag, ProjectTaggedDataUnit
from encord_active.db.models import (
Project,
ProjectAnnotationAnalytics,
ProjectDataAnalytics,
ProjectDataUnitMetadata,
ProjectPrediction,
ProjectPredictionAnalytics,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from encord_active.lib.common.data_utils import url_to_file_path

_P = Project
_T = ProjectTag


@dataclass
class DataUnitItem:
du_hash: str
du_hash: UUID
frame: int


def get_active_engine(path_to_db: str) -> Engine:
return create_engine(f"sqlite:///{path_to_db}")
AnalyticsModel = Union[Type[ProjectDataAnalytics], Type[ProjectPredictionAnalytics], Type[ProjectAnnotationAnalytics]]


def get_active_engine(path_to_db: Union[str, Path]) -> Engine:
path = Path(path_to_db) if isinstance(path_to_db, str) else path_to_db
return get_engine(path, use_alembic=False)


class ActiveProject:
def __init__(self, engine: Engine, project_name: str):
@classmethod
def from_db_file(cls, db_file_path: Union[str, Path], project_name: str):
path = Path(db_file_path) if isinstance(db_file_path, str) else db_file_path
engine = get_active_engine(db_file_path)
return ActiveProject(engine, project_name, root_path=path.parent)

def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optional[Path] = None):
self._engine = engine
self._metadata = MetaData(bind=self._engine)
self._project_name = project_name
self._root_path = root_path

active_project = Table("active_project", self._metadata, autoload_with=self._engine)
stmt = select(active_project.c.project_hash).where(active_project.c.project_name == f"{self._project_name}")
with Session(self._engine) as sess:
res = sess.exec(
select(_P.project_hash, _P.project_ontology).where(_P.project_name == project_name).limit(1)
).first()

with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()
if res is None:
raise ValueError(f"Couldn't find project with name `{project_name}` in the DB.")

if result is not None:
self._project_hash = result[0]
else:
self._project_hash = None
self.project_hash, ontology = res
project_tuples = sess.exec(select(_T.name, _T.tag_hash).where(_T.project_hash == self.project_hash)).all()

with Session(engine) as sess:
self._existing_tags = {
tag.name: tag.tag_hash
for tag in sess.exec(
sqlmodel_select(ProjectTag).where(ProjectTag.project_hash == self._project_hash)
).all()
}
sess.commit()
# Assuming that there's just one prediction model
# FIXME: With multiple sets of model predictions, we should select the right UUID here
self._model_hash = sess.exec(
select(ProjectPrediction.prediction_hash)
.where(ProjectPrediction.project_hash == self.project_hash)
.limit(1)
).first()

def _execute_statement(self, stmt: Select) -> Any:
with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()

if result is not None:
return result[0]
else:
return None
self._existing_tags = dict(project_tuples)
self.ontology = OntologyStructure.from_dict(ontology) # type: ignore

# For backward compatibility
def get_ontology(self) -> OntologyStructure:
active_project = Table("active_project", self._metadata, autoload_with=self._engine)

stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}")
return self.ontology

def _join_path_statement(self, stmt, base_model: AnalyticsModel):
stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video)
stmt = stmt.join(
ProjectDataUnitMetadata,
onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash)
& (base_model.frame == ProjectDataUnitMetadata.frame),
).where(ProjectDataUnitMetadata.project_hash == self.project_hash)

def transform(df):
_root_path = self._root_path
if _root_path is None:
raise ValueError("Root path is not set. Provide it in the constructor or use `from_db_file`")

df["data_uri"] = df.data_uri.map(lambda p: url_to_file_path(p, project_dir=_root_path) if p else p)
return df

return stmt, transform

def _join_data_tags_statement(self, stmt, base_model: AnalyticsModel, group_by_cols=None):
stmt = stmt.add_columns(("[" + func.group_concat('"' + ProjectTag.name + '"', ", ") + "]").label("data_tags"))
stmt = (
stmt.outerjoin(
ProjectTaggedDataUnit,
onclause=(base_model.du_hash == ProjectTaggedDataUnit.du_hash)
& (base_model.frame == ProjectTaggedDataUnit.frame),
)
.outerjoin(
ProjectTag,
onclause=(
(ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash)
& (ProjectTaggedDataUnit.project_hash == self.project_hash)
),
)
.group_by(base_model.du_hash, base_model.frame, *(group_by_cols or []))
)

return OntologyStructure.from_dict(self._execute_statement(stmt))
def transform(df):
df["data_tags"] = df.data_tags.map(lambda x: eval(x) if x else [])
return df

def get_prediction_metrics(self) -> pd.DataFrame:
active_project_prediction = Table("active_project_prediction", self._metadata, autoload_with=self._engine)
stmt = select(active_project_prediction.c.prediction_hash).where(
active_project_prediction.c.project_hash == f"{self._project_hash}"
)
return stmt, transform

prediction_hash = self._execute_statement(stmt)
def get_prediction_metrics(self, include_data_uris: bool = False, include_data_tags: bool = False) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.
active_project_prediction_analytics = Table(
"active_project_prediction_analytics", self._metadata, autoload_with=self._engine
)
Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names.
stmt = select(
[
active_project_prediction_analytics.c.du_hash,
active_project_prediction_analytics.c.feature_hash,
active_project_prediction_analytics.c.metric_area,
active_project_prediction_analytics.c.metric_area_relative,
active_project_prediction_analytics.c.metric_aspect_ratio,
active_project_prediction_analytics.c.metric_brightness,
active_project_prediction_analytics.c.metric_contrast,
active_project_prediction_analytics.c.metric_sharpness,
active_project_prediction_analytics.c.metric_red,
active_project_prediction_analytics.c.metric_green,
active_project_prediction_analytics.c.metric_blue,
active_project_prediction_analytics.c.metric_label_border_closeness,
active_project_prediction_analytics.c.metric_label_confidence,
text(
"""CASE
WHEN feature_hash == match_feature_hash THEN 1
ELSE 0
END AS true_positive
"""
),
]
).where(active_project_prediction_analytics.c.prediction_hash == prediction_hash)
"""
if self._model_hash is None:
raise ValueError(f"Project with name `{self._project_name}` does not have any model predictions")

with self._engine.begin() as conn:
df = pd.read_sql(stmt, conn)
with Session(self._engine) as sess:
P = ProjectPredictionAnalytics
stmt = select( # type: ignore
P.du_hash,
P.feature_hash,
P.metric_area,
P.metric_area_relative,
P.metric_aspect_ratio,
P.metric_brightness,
P.metric_contrast,
P.metric_sharpness,
P.metric_red,
P.metric_green,
P.metric_blue,
P.metric_label_border_closeness,
P.metric_label_confidence,
(P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore
).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash)

transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, P)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, P, group_by_cols=[P.object_hash])
transforms.append(transform)

df = pd.DataFrame(sess.exec(stmt).all())
df.columns = list(sess.exec(stmt).keys()) or df.columns # type: ignore

for transform in transforms:
df = transform(df)

return df

def get_images_metrics(self) -> pd.DataFrame:
active_project_analytics_data = Table(
"active_project_analytics_data", self._metadata, autoload_with=self._engine
)
stmt = select(active_project_analytics_data).where(
active_project_analytics_data.c.project_hash == f"{self._project_hash}"
)
def get_images_metrics(self, *, include_data_uris: bool = False, include_data_tags: bool = True) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.
Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names.
"""
with Session(self._engine) as sess:
# hack to get all columns without the pydantic model
stmt = select(*[c for c in ProjectDataAnalytics.__table__.c][:3]).where( # type: ignore
ProjectDataAnalytics.project_hash == self.project_hash
)
transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
image_metrics = sess.exec(stmt).all()

df = pd.DataFrame(image_metrics)
df.columns = list(sess.execute(stmt).keys()) or df.columns # type: ignore

with self._engine.begin() as conn:
df = pd.read_sql(stmt, conn)
for transform in transforms:
df = transform(df)

return df

Expand All @@ -129,7 +208,7 @@ def get_or_add_tag(self, tag: str) -> UUID:
new_tag = ProjectTag(
tag_hash=tag_hash,
name=tag,
project_hash=self._project_hash,
project_hash=self.project_hash,
description="",
)
sess.add(new_tag)
Expand All @@ -146,7 +225,7 @@ def add_tag_to_data_units(
for du_item in du_items:
sess.add(
ProjectTaggedDataUnit(
project_hash=self._project_hash,
project_hash=self.project_hash,
du_hash=du_item.du_hash,
frame=du_item.frame,
tag_hash=tag_hash,
Expand Down

0 comments on commit b4f46ff

Please sign in to comment.