diff --git a/src/encord_active/public/active_project.py b/src/encord_active/public/active_project.py index 753e63278..ea9787224 100644 --- a/src/encord_active/public/active_project.py +++ b/src/encord_active/public/active_project.py @@ -6,7 +6,7 @@ import pandas as pd from encord.objects import OntologyStructure -from sqlalchemy import Integer +from sqlalchemy import Integer, func from sqlalchemy.engine import Engine from sqlmodel import Session, select @@ -79,7 +79,53 @@ def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optio def get_ontology(self) -> OntologyStructure: return self.ontology - def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFrame: + def _join_path_statement(self, stmt, base_model: AnalyticsModel): + stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video) + stmt = stmt.join( + ProjectDataUnitMetadata, + onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash) + & (base_model.frame == ProjectDataUnitMetadata.frame), + ).where(ProjectDataUnitMetadata.project_hash == self.project_hash) + + def transform(df): + if self._root_path is None: + raise ValueError(f"Root path is not set. Provide it in the constructor or use `from_db_file`") + df["data_uri"] = df.data_uri.map(partial(url_to_file_path, project_dir=self._root_path)) + return df + + return stmt, transform + + def _join_data_tags_statement(self, stmt, base_model: AnalyticsModel, group_by_cols=None): + stmt = stmt.add_columns(("[" + func.group_concat('"' + ProjectTag.name + '"', ", ") + "]").label("data_tags")) + stmt = ( + stmt.join( + ProjectTaggedDataUnit, + onclause=(base_model.du_hash == ProjectTaggedDataUnit.du_hash) + & (base_model.frame == ProjectTaggedDataUnit.frame), + ) + .join(ProjectTag, onclause=ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash) + .where( + ProjectTag.project_hash == self.project_hash, ProjectTaggedDataUnit.project_hash == self.project_hash + ) + .group_by(base_model.du_hash, base_model.frame, *(group_by_cols or [])) + ) + + def transform(df): + df["data_tags"] = df.data_tags.map(eval) + return df + + return stmt, transform + + def get_prediction_metrics(self, include_data_uris: bool = False, include_data_tags: bool = False) -> pd.DataFrame: + """ + Returns a pandas data frame with all the prediction metrics. + + Args: + include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file. + include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image. + Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names. + + """ if self._model_hash is None: raise ValueError(f"Project with name {self._project_name} does not have any model predictions") @@ -102,47 +148,47 @@ def get_prediction_metrics(self, include_data_uris: bool = False) -> pd.DataFram (P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore ).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash) - transform = None + transforms = [] if include_data_uris: stmt, transform = self._join_path_statement(stmt, P) + transforms.append(transform) + if include_data_tags: + stmt, transform = self._join_data_tags_statement(stmt, P, group_by_cols=[P.object_hash]) + transforms.append(transform) df = pd.DataFrame(sess.exec(stmt).all()) df.columns = list(sess.exec(stmt).keys()) - if transform is not None: + for transform in transforms: df = transform(df) return df - def _join_path_statement(self, stmt, base_model: AnalyticsModel): - stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video) - stmt = stmt.join( - ProjectDataUnitMetadata, - onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash) - & (base_model.frame == ProjectDataUnitMetadata.frame), - ).where(ProjectDataUnitMetadata.project_hash == self.project_hash) - - def transform(df): - if self._root_path is None: - raise ValueError(f"Root path is not set. Provide it in the constructor or use `from_db_file`") - df["data_uri"] = df.data_uri.map(partial(url_to_file_path, project_dir=self._root_path)) - return df - - return stmt, transform + def get_images_metrics(self, *, include_data_uris: bool = False, include_data_tags: bool = True) -> pd.DataFrame: + """ + Returns a pandas data frame with all the prediction metrics. - def get_images_metrics(self, *, include_data_uris: bool = False) -> pd.DataFrame: + Args: + include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file. + include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image. + Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names. + """ with Session(self._engine) as sess: stmt = select(*[c for c in ProjectDataAnalytics.__table__.c]).where( # type: ignore -- hack to get all columns without the pydantic model ProjectDataAnalytics.project_hash == self.project_hash ) - transform = None + transforms = [] if include_data_uris: stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics) + transforms.append(transform) + if include_data_tags: + stmt, transform = self._join_data_tags_statement(stmt, ProjectDataAnalytics) + transforms.append(transform) image_metrics = sess.exec(stmt).all() df = pd.DataFrame(image_metrics) df.columns = list(sess.execute(stmt).keys()) - if transform is not None: + for transform in transforms: df = transform(df) return df