Skip to content

Commit

Permalink
feat: add support for including tags as well
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-encord committed Dec 6, 2023
1 parent 30a5ec2 commit 1c7309c
Showing 1 changed file with 67 additions and 21 deletions.
88 changes: 67 additions & 21 deletions src/encord_active/public/active_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pandas as pd
from encord.objects import OntologyStructure
from sqlalchemy import Integer
from sqlalchemy import Integer, func
from sqlalchemy.engine import Engine
from sqlmodel import Session, select

Expand Down Expand Up @@ -79,7 +79,53 @@ def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optio
def get_ontology(self) -> OntologyStructure:
return self.ontology

def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFrame:
def _join_path_statement(self, stmt, base_model: AnalyticsModel):
stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video)
stmt = stmt.join(
ProjectDataUnitMetadata,
onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash)
& (base_model.frame == ProjectDataUnitMetadata.frame),
).where(ProjectDataUnitMetadata.project_hash == self.project_hash)

def transform(df):
if self._root_path is None:
raise ValueError(f"Root path is not set. Provide it in the constructor or use `from_db_file`")
df["data_uri"] = df.data_uri.map(partial(url_to_file_path, project_dir=self._root_path))
return df

return stmt, transform

def _join_data_tags_statement(self, stmt, base_model: AnalyticsModel, group_by_cols=None):
stmt = stmt.add_columns(("[" + func.group_concat('"' + ProjectTag.name + '"', ", ") + "]").label("data_tags"))
stmt = (
stmt.join(
ProjectTaggedDataUnit,
onclause=(base_model.du_hash == ProjectTaggedDataUnit.du_hash)
& (base_model.frame == ProjectTaggedDataUnit.frame),
)
.join(ProjectTag, onclause=ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash)
.where(
ProjectTag.project_hash == self.project_hash, ProjectTaggedDataUnit.project_hash == self.project_hash
)
.group_by(base_model.du_hash, base_model.frame, *(group_by_cols or []))
)

def transform(df):
df["data_tags"] = df.data_tags.map(eval)
return df

return stmt, transform

def get_prediction_metrics(self, include_data_uris: bool = False, include_data_tags: bool = False) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.
Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\/` in your tag names.
"""
if self._model_hash is None:
raise ValueError(f"Project with name {self._project_name} does not have any model predictions")

Expand All @@ -102,9 +148,13 @@ def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFram
(P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore
).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash)

transform = None
transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, P)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, P, group_by_cols=[P.object_hash])
transforms.append(transform)

df = pd.DataFrame(sess.exec(stmt).all())
df.columns = list(sess.exec(stmt).keys())
Expand All @@ -114,35 +164,31 @@ def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFram

return df

def _join_path_statement(self, stmt, base_model: AnalyticsModel):
stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video)
stmt = stmt.join(
ProjectDataUnitMetadata,
onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash)
& (base_model.frame == ProjectDataUnitMetadata.frame),
).where(ProjectDataUnitMetadata.project_hash == self.project_hash)

def transform(df):
if self._root_path is None:
raise ValueError(f"Root path is not set. Provide it in the constructor or use `from_db_file`")
df["data_uri"] = df.data_uri.map(partial(url_to_file_path, project_dir=self._root_path))
return df
def get_images_metrics(self, *, include_data_uris: bool = False, include_data_tags: bool = True) -> pd.DataFrame:
"""
Returns a pandas data frame with all the prediction metrics.
return stmt, transform

def get_images_metrics(self, *, include_data_uris: bool = False) -> pd.DataFrame:
Args:
include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file.
include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image.
Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\/` in your tag names.
"""
with Session(self._engine) as sess:
stmt = select(*[c for c in ProjectDataAnalytics.__table__.c]).where( # type: ignore -- hack to get all columns without the pydantic model
ProjectDataAnalytics.project_hash == self.project_hash
)
transform = None
transforms = []
if include_data_uris:
stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
if include_data_tags:
stmt, transform = self._join_data_tags_statement(stmt, ProjectDataAnalytics)
transforms.append(transform)
image_metrics = sess.exec(stmt).all()
df = pd.DataFrame(image_metrics)
df.columns = list(sess.execute(stmt).keys())

if transform is not None:
for transform in transforms:
df = transform(df)

return df
Expand Down

0 comments on commit 1c7309c

Please sign in to comment.