Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/nrtk_explorer/app/features/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def add_images(self, dataset_id_to_image: IdToImage):
self.transformed_features.update(id_to_feature)
self.emit_update()

@change("dataset_ids")
def on_dataset_ids(self, **kwargs):
@change("user_selected_ids")
def on_user_selected_ids(self, **kwargs):
self.transformed_features = {
id: features
for id, features in self.transformed_features.items()
if image_id_to_dataset_id(id) in self.server.state.dataset_ids
if image_id_to_dataset_id(id) in self.server.state.user_selected_ids
}
self.emit_update()

Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
# Local initialization if standalone
self.is_standalone_app = self.server.root_server == self.server
if self.is_standalone_app and datasets:
self.state.dataset_ids = []
self.state.user_selected_ids = []
self.state.current_dataset = datasets[0]
self.context.dataset = get_dataset(self.state.current_dataset)

Expand Down Expand Up @@ -115,7 +115,7 @@ def on_server_ready(self, *args, **kwargs):
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.save_embedding_params()
self.update_points()
self.state.change("dataset_ids")(self.update_points)
self.state.change("user_selected_ids")(self.update_points)
self.ctrl.apply_transform.add(self.clear_points_transformations)
self.ctrl.apply_transform.add(self.transformed_images.clear)
self.state.change("transform_enabled_switch")(self.update_points_transformations_state)
Expand All @@ -127,7 +127,7 @@ def on_feature_extraction_model_change(self, **kwargs):
self.transformed_images.set_extractor(self.extractor)

def compute_points(self, fit_features, features):
if len(features) == 0:
if len(features) <= 1:
# reduce will fail if no features
return []

Expand Down Expand Up @@ -171,14 +171,14 @@ def update_points_transformations_state(self, **kwargs):

def compute_source_points(self):
images = (
self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
self.images.get_image_without_cache_eviction(id) for id in self.state.user_selected_ids
)
self.features = self.extractor.extract(images)

points = self.compute_points(self.features, self.features)

self.state.points_sources = {
id: point for id, point in zip(self.state.dataset_ids, points)
id: point for id, point in zip(self.state.user_selected_ids, points)
}

self.state.camera_position = []
Expand Down Expand Up @@ -243,15 +243,15 @@ def update_transformed_points(self, id_to_features):
self.update_points_transformations_state()

def on_scatter_select(self, image_ids):
self.state.user_selected_ids = image_ids or self.state.dataset_ids
self.state.user_selected_ids = image_ids or self.state.user_selected_ids

def on_move(self, camera_position):
self.state.camera_position = camera_position

def get_dataset_id_index(self, point_index):
if point_index < len(self.state.dataset_ids):
if point_index < len(self.state.user_selected_ids):
return point_index
return point_index - len(self.state.dataset_ids)
return point_index - len(self.state.user_selected_ids)

def on_point_hover(self, event):
self.state.highlighted_image = event
Expand Down
Loading