Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,31 @@ classifiers = [
]

dependencies = [
"accelerate",
"nrtk>=0.4.2",
"numpy",
"Pillow",
"scikit-learn==1.5.0",
"accelerate",
"smqtk_image_io",
"tabulate",
"transformers",
"timm>=1.0.3",
"torch",
"torchvision",
"trame",
"trame-client>=2.15.0",
"trame-quasar",
"trame-server>=2.15.0",
"transformers",
"umap-learn",
"tabulate",
]

[project.optional-dependencies]
dev = [
"black",
"flake8",
"mypy",
"pytest",
"tabulate",
"mypy",
]

package = [
Expand Down
12 changes: 0 additions & 12 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library import images_manager
from nrtk_explorer.app.applet import Applet
from nrtk_explorer.app.image_meta import update_image_meta
import nrtk_explorer.test_data

import asyncio
Expand Down Expand Up @@ -162,17 +161,6 @@ def on_run_transformations(self, transformed_image_ids):
**args,
)

for index, image_id in enumerate(transformed_image_ids):
transform_point = self.state.points_transformations[index]
original_image_point_index = self.state.user_selected_points_indices[index]
original_point = self.state.points_sources[original_image_point_index]
distance = sum(
(transform_point[i] - original_point[i]) ** 2 for i in range(len(original_point))
)
distance = distance**0.5
dataset_id = image_id.split("_")[-1]
update_image_meta(self.state, dataset_id, {"distance": distance})

def set_on_select(self, fn):
self._on_select_fn = fn

Expand Down
140 changes: 106 additions & 34 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@
from nrtk_explorer.library import images_manager, object_detector
from nrtk_explorer.app import ui
from nrtk_explorer.app.applet import Applet
from nrtk.impls.score_detections.coco_scorer import COCOScorer
from nrtk_explorer.app.parameters import ParametersApp
from nrtk_explorer.app.image_meta import update_image_meta, delete_image_meta
from nrtk_explorer.app.image_meta import (
update_image_meta,
delete_image_meta,
)
from nrtk_explorer.library.coco_utils import (
convert_from_ground_truth_to_first_arg,
convert_from_ground_truth_to_second_arg,
convert_from_predictions_to_second_arg,
convert_from_predictions_to_first_arg,
)
import nrtk_explorer.test_data

import json
Expand Down Expand Up @@ -87,6 +97,7 @@ def __init__(self, server):

self.state.source_image_ids = []
self.state.transformed_image_ids = []

# Image kinds (Original, Transformed, ...) to display per dataset image in ImageList
self.state.image_kinds = [
{"image_ids_list": "source_image_ids", "readable": "Original"},
Expand All @@ -102,13 +113,14 @@ def __init__(self, server):

self.server.controller.add("on_server_ready")(self.on_server_ready)
self._on_hover_fn = None
self.detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")
self.detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.state.change("current_dataset")(self.on_current_dataset_change)
self.state.change("current_num_elements")(self.on_current_num_elements_change)
self.state.change("enabled_model_images")(self.on_detection_model_change)

self.on_current_dataset_change(self.state.current_dataset)

Expand Down Expand Up @@ -138,24 +150,66 @@ def on_apply_transform(self, *args, **kwargs):

self.state[transformed_image_id] = images_manager.convert_to_base64(transformed_img)

self.state.transformed_image_ids = transformed_image_ids
if len(self.state.source_image_ids) > 0:
self.state.hovered_id = ""

self.state.transformed_image_ids = transformed_image_ids
self.compute_annotations(transformed_image_ids)

# Only invoke callbacks when we transform images
if len(transformed_image_ids) > 0:
with open(self.state.current_dataset) as f:
dataset = json.load(f)

# Erase current annotations
ids = [int(id_.split("_")[-1]) for id_ in self.state.source_image_ids]
for ann in dataset["annotations"]:
if ann["image_id"] in ids:
transformed_id = f"transformed_img_{ann['image_id']}"
if transformed_id in self.context["annotations"]:
del self.context["annotations"][transformed_id]

# Compute new annotations
if "transformation" in self.state.enabled_model_images:
predictions = convert_from_predictions_to_second_arg(
self.compute_annotations(transformed_image_ids)
)
else:
for ann in dataset["annotations"]:
if ann["image_id"] in ids:
image_annotations = self.context["annotations"].setdefault(
f"transformed_img_{ann['image_id']}", []
)
image_annotations.append(ann)

predictions = convert_from_ground_truth_to_second_arg(
[
self.context["annotations"][id_]
for id_ in self.state.transformed_image_ids
],
dataset,
)

self.compute_score(
self.state.source_image_ids, self.predictions_source_images, predictions
)

self.update_model_result(transformed_image_ids)

# Only invoke callbacks when we transform images
self.on_transform(transformed_image_ids)

def compute_score(self, ids, predictions_source_images, predictions_trans_images):
"""Compute the score for the given image ids using the COCO scorer."""
score_output = COCOScorer(self.state.current_dataset).score(
predictions_source_images, predictions_trans_images
)

for image_id, score in zip(ids, score_output):
update_image_meta(self.state, image_id.split("_")[-1], {"distance": score})

def compute_annotations(self, ids):
"""Compute annotations for the given image ids using the object detector model."""
if len(ids) == 0:
return

for id_ in ids:
self.context["annotations"][id_] = []

predictions = self.detector.eval(paths=ids, content=self.context.image_objects)

for id_, annotations in predictions:
Expand All @@ -180,19 +234,50 @@ def compute_annotations(self, ids):
}
)

self.update_model_result(ids, self.state.feature_extraction_model)
return predictions

def on_current_num_elements_change(self, current_num_elements, **kwargs):
with open(self.state.current_dataset) as f:
dataset = json.load(f)
ids = [img["id"] for img in dataset["images"]]
return self.set_source_images(ids[:current_num_elements])

def on_detection_model_change(self, enabled_model_images, **kwargs):
"""Update the model result when the detection model changes."""
self.compute_predictions_source_images(self.state.source_image_ids)
self.update_model_result(self.state.source_image_ids)
self.on_apply_transform()

def compute_predictions_source_images(self, ids):
"""Compute the predictions for the source images."""
if len(ids) > 0:
with open(self.state.current_dataset) as f:
dataset = json.load(f)
for id_ in ids:
del self.context["annotations"][id_]

if "source" in self.state.enabled_model_images:
self.predictions_source_images = convert_from_predictions_to_first_arg(
self.compute_annotations(ids),
dataset,
ids,
)
else:
for annotation in dataset["annotations"]:
image_id = f"img_{annotation['image_id']}"
image_annotations = self.context["annotations"].setdefault(image_id, [])
image_annotations.append(annotation)

self.predictions_source_images = convert_from_ground_truth_to_first_arg(
[self.context["annotations"][id_] for id_ in ids]
)

def _update_images(self, selected_ids):
source_image_ids = []

current_dir = os.path.dirname(self.state.current_dataset)

dataset = None
with open(self.state.current_dataset) as f:
dataset = json.load(f)

Expand Down Expand Up @@ -223,8 +308,9 @@ def _update_images(self, selected_ids):
self.state.hovered_id = ""

self.state.source_image_ids = source_image_ids
self.compute_annotations(source_image_ids)
self.update_model_result(self.state.source_image_ids, self.state.feature_extraction_model)
self.compute_predictions_source_images(self.state.source_image_ids)

self.update_model_result(source_image_ids)
self.on_apply_transform()

async def _set_source_images(self, selected_ids):
Expand Down Expand Up @@ -296,7 +382,11 @@ def on_current_dataset_change(self, current_dataset, **kwargs):

self.state.annotation_categories = categories

self.context["annotations"] = {}
if "source" not in self.state.enabled_model_images:
for annotation in dataset["annotations"]:
image_id = f"img_{annotation['image_id']}"
image_annotations = self.context["annotations"].setdefault(image_id, [])
image_annotations.append(annotation)

self.context.image_id_to_index = {}
for i, image in enumerate(dataset["images"]):
Expand All @@ -308,12 +398,10 @@ def on_current_dataset_change(self, current_dataset, **kwargs):
def on_feature_extraction_model_change(self, **kwargs):
logger.debug(f">>> on_feature_extraction_model_change change {self.state}")

feature_extraction_model = self.state.feature_extraction_model
self.update_model_result(self.state.source_image_ids)
self.update_model_result(self.state.transformed_image_ids)

self.update_model_result(self.state.source_image_ids, feature_extraction_model)
self.update_model_result(self.state.transformed_image_ids, feature_extraction_model)

def update_model_result(self, image_ids, feature_extraction_model):
def update_model_result(self, image_ids):
for image_id in image_ids:
result_id = image_id_to_result(image_id)
self.state[result_id] = self.context["annotations"].get(image_id, [])
Expand All @@ -333,22 +421,6 @@ def on_hover(self, hover_event):
def settings_widget(self):
with html.Div(trame_server=self.server):
with html.Div(classes="col"):
quasar.QSelect(
label="Object detection Model",
v_model=("object_detection_model", "facebook/detr-resnet-50"),
options=(
[
{
"label": "facebook/detr-resnet-50",
"value": "facebook/detr-resnet-50",
},
],
),
filled=True,
emit_value=True,
map_options=True,
)

self._parameters_app.transform_select_ui()

with html.Div(
Expand Down
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/ui/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, hover_fn=None, **kwargs):
{ name: 'id', label: 'ID', field: 'id', sortable: true },
{ name: 'original', label: 'Original Image', field: 'original' },
{ name: 'transformed', label: 'Transformed Image', field: 'transformed' },
{ name: 'distance', label: 'Transformed Embedding Distance', field: 'distance', sortable: true },
{ name: 'distance', label: 'Annotations similarity score', field: 'distance', sortable: true },
{ name: 'width', label: 'Width', field: 'width', sortable: true },
{ name: 'height', label: 'Height', field: 'height', sortable: true },
]""",
Expand Down
38 changes: 38 additions & 0 deletions src/nrtk_explorer/app/ui/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,44 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
with embeddings_actions_slot:
embeddings_app.compute_ui()

(annotations_title_slot, annotations_content_slot, _) = ui.card("collapse_annotations")

with annotations_title_slot:
html.Span("Annotations settings", classes="text-h6")

with annotations_content_slot:
quasar.QSelect(
label="Object detection Model",
v_model=("object_detection_model", "facebook/detr-resnet-50"),
options=(
[
{
"label": "facebook/detr-resnet-50",
"value": "facebook/detr-resnet-50",
},
],
),
filled=True,
emit_value=True,
map_options=True,
)
quasar.QOptionGroup(
v_model=("enabled_model_images", ["transformation"]),
options=(
[
{
"label": "Source images",
"value": "source",
},
{
"label": "Transformation images",
"value": "transformation",
},
],
),
type="toggle",
)

filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter")

with filter_title_slot:
Expand Down
Loading