Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: active project export data #680

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 164 additions & 85 deletions src/encord_active/public/active_project.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,201 @@
from dataclasses import dataclass
from typing import Any
from pathlib import Path
from typing import Optional, Type, Union
from uuid import UUID, uuid4

import pandas as pd
from encord.objects import OntologyStructure
from sqlalchemy import MetaData, Table, create_engine, select, text
from sqlalchemy import Integer, func
from sqlalchemy.engine import Engine
from sqlalchemy.sql import Select
from sqlmodel import Session
from sqlmodel import select as sqlmodel_select
from sqlmodel import Session, select

from encord_active.db.models import ProjectTag, ProjectTaggedDataUnit
from encord_active.db.models import (
Project,
ProjectAnnotationAnalytics,
ProjectDataAnalytics,
ProjectDataUnitMetadata,
ProjectPrediction,
ProjectPredictionAnalytics,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from encord_active.lib.common.data_utils import url_to_file_path

_P = Project
_T = ProjectTag


@dataclass
class DataUnitItem:
du_hash: str
du_hash: UUID
frame: int


def get_active_engine(path_to_db: str) -> Engine:
return create_engine(f"sqlite:///{path_to_db}")
AnalyticsModel = Union[Type[ProjectDataAnalytics], Type[ProjectPredictionAnalytics], Type[ProjectAnnotationAnalytics]]


def get_active_engine(path_to_db: Union[str, Path]) -> Engine:
path = Path(path_to_db) if isinstance(path_to_db, str) else path_to_db
return get_engine(path, use_alembic=False)


class ActiveProject:
def __init__(self, engine: Engine, project_name: str):
@classmethod
def from_db_file(cls, db_file_path: Union[str, Path], project_name: str):
path = Path(db_file_path) if isinstance(db_file_path, str) else db_file_path
engine = get_active_engine(db_file_path)
return ActiveProject(engine, project_name, root_path=path.parent)

def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optional[Path] = None):
self._engine = engine
self._metadata = MetaData(bind=self._engine)
self._project_name = project_name
self._root_path = root_path

active_project = Table("active_project", self._metadata, autoload_with=self._engine)
stmt = select(active_project.c.project_hash).where(active_project.c.project_name == f"{self._project_name}")
with Session(self._engine) as sess:
res = sess.exec(
select(_P.project_hash, _P.project_ontology).where(_P.project_name == project_name).limit(1)
).first()

with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()
if res is None:
raise ValueError(f"Couldn't find project with name `{project_name}` in the DB.")

if result is not None:
self._project_hash = result[0]
else:
self._project_hash = None
self.project_hash, ontology = res
project_tuples = sess.exec(select(_T.name, _T.tag_hash).where(_T.project_hash == self.project_hash)).all()

with Session(engine) as sess:
self._existing_tags = {
tag.name: tag.tag_hash
for tag in sess.exec(
sqlmodel_select(ProjectTag).where(ProjectTag.project_hash == self._project_hash)
).all()
}
sess.commit()
# Assuming that there's just one prediction model
# FIXME: With multiple sets of model predictions, we should select the right UUID here
self._model_hash = sess.exec(
select(ProjectPrediction.prediction_hash)
.where(ProjectPrediction.project_hash == self.project_hash)
.limit(1)
).first()

def _execute_statement(self, stmt: Select) -> Any:
with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()

if result is not None:
return result[0]
else:
return None
self._existing_tags = dict(project_tuples)
self.ontology = OntologyStructure.from_dict(ontology) # type: ignore

# For backward compatibility
def get_ontology(self) -> OntologyStructure:
active_project = Table("active_project", self._metadata, autoload_with=self._engine)

stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}")
return self.ontology

def _join_path_statement(self, stmt, base_model: AnalyticsModel):
stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video)
stmt = stmt.join(
ProjectDataUnitMetadata,
onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash)
& (base_model.frame == ProjectDataUnitMetadata.frame),
).where(ProjectDataUnitMetadata.project_hash == self.project_hash)

def transform(df):
_root_path = self._root_path
if _root_path is None:
raise ValueError("Root path is not set. Provide it in the constructor or use `from_db_file`")

df["data_uri"] = df.data_uri.map(lambda p: url_to_file_path(p, project_dir=_root_path) if p else p)
return df

return stmt, transform

def _join_data_tags_statement(self, stmt, base_model: AnalyticsModel, group_by_cols=None):
stmt = stmt.add_columns(("[" + func.group_concat('"' + ProjectTag.name + '"', ", ") + "]").label("data_tags"))
stmt = (
stmt.outerjoin(
ProjectTaggedDataUnit,
onclause=(base_model.du_hash == ProjectTaggedDataUnit.du_hash)
& (base_model.frame == ProjectTaggedDataUnit.frame),
)
.outerjoin(
ProjectTag,
onclause=(
(ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash)
& (ProjectTaggedDataUnit.project_hash == self.project_hash)
),
)
.group_by(base_model.du_hash, base_model.frame, *(group_by_cols or []))
)

return OntologyStructure.from_dict(self._execute_statement(stmt))
def transform(df):
df["data_tags"] = df.data_tags.map(lambda x: eval(x) if x else [])
return df

def get_prediction_metrics(self) -> pd.DataFrame:
active_project_prediction = Table("active_project_prediction", self._metadata, autoload_with=self._engine)
stmt = select(active_project_prediction.c.prediction_hash).where(
active_project_prediction.c.project_hash == f"{self._project_hash}"
)
return stmt, transform

prediction_hash = self._execute_statement(stmt)
def get_prediction_metrics(self, include_data_uris: bool = False, include_data_tags: bool = False) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.

active_project_prediction_analytics = Table(
"active_project_prediction_analytics", self._metadata, autoload_with=self._engine
)
Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names.

stmt = select(
[
active_project_prediction_analytics.c.du_hash,
active_project_prediction_analytics.c.feature_hash,
active_project_prediction_analytics.c.metric_area,
active_project_prediction_analytics.c.metric_area_relative,
active_project_prediction_analytics.c.metric_aspect_ratio,
active_project_prediction_analytics.c.metric_brightness,
active_project_prediction_analytics.c.metric_contrast,
active_project_prediction_analytics.c.metric_sharpness,
active_project_prediction_analytics.c.metric_red,
active_project_prediction_analytics.c.metric_green,
active_project_prediction_analytics.c.metric_blue,
active_project_prediction_analytics.c.metric_label_border_closeness,
active_project_prediction_analytics.c.metric_label_confidence,
text(
"""CASE
WHEN feature_hash == match_feature_hash THEN 1
ELSE 0
END AS true_positive
"""
),
]
).where(active_project_prediction_analytics.c.prediction_hash == prediction_hash)
"""
if self._model_hash is None:
raise ValueError(f"Project with name `{self._project_name}` does not have any model predictions")

with self._engine.begin() as conn:
df = pd.read_sql(stmt, conn)
with Session(self._engine) as sess:
P = ProjectPredictionAnalytics
stmt = select( # type: ignore
P.du_hash,
P.feature_hash,
P.metric_area,
P.metric_area_relative,
P.metric_aspect_ratio,
P.metric_brightness,
P.metric_contrast,
P.metric_sharpness,
P.metric_red,
P.metric_green,
P.metric_blue,
P.metric_label_border_closeness,
P.metric_label_confidence,
(P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore
).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash)

transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, P)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, P, group_by_cols=[P.object_hash])
transforms.append(transform)

df = pd.DataFrame(sess.exec(stmt).all())
df.columns = list(sess.exec(stmt).keys()) or df.columns # type: ignore

for transform in transforms:
df = transform(df)

return df

def get_images_metrics(self) -> pd.DataFrame:
active_project_analytics_data = Table(
"active_project_analytics_data", self._metadata, autoload_with=self._engine
)
stmt = select(active_project_analytics_data).where(
active_project_analytics_data.c.project_hash == f"{self._project_hash}"
)
def get_images_metrics(self, *, include_data_uris: bool = False, include_data_tags: bool = True) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.

Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names.
"""
with Session(self._engine) as sess:
# hack to get all columns without the pydantic model
stmt = select(*[c for c in ProjectDataAnalytics.__table__.c][:3]).where( # type: ignore
ProjectDataAnalytics.project_hash == self.project_hash
)
transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
image_metrics = sess.exec(stmt).all()

df = pd.DataFrame(image_metrics)
df.columns = list(sess.execute(stmt).keys()) or df.columns # type: ignore

with self._engine.begin() as conn:
df = pd.read_sql(stmt, conn)
for transform in transforms:
df = transform(df)

return df

Expand All @@ -129,7 +208,7 @@ def get_or_add_tag(self, tag: str) -> UUID:
new_tag = ProjectTag(
tag_hash=tag_hash,
name=tag,
project_hash=self._project_hash,
project_hash=self.project_hash,
description="",
)
sess.add(new_tag)
Expand All @@ -146,7 +225,7 @@ def add_tag_to_data_units(
for du_item in du_items:
sess.add(
ProjectTaggedDataUnit(
project_hash=self._project_hash,
project_hash=self.project_hash,
du_hash=du_item.du_hash,
frame=du_item.frame,
tag_hash=tag_hash,
Expand Down