Skip to content

Commit

Permalink
more robust initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Jan 10, 2025
1 parent 5e45561 commit d5fc685
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
15 changes: 15 additions & 0 deletions fiftyone/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ def load_and_cache_dataset(name):
return dataset


def cache_dataset(dataset):
"""Caches the given dataset.
This method ensures that subsequent calls to
:func:`fiftyone.core.dataset.load_dataset` in async calls will return this
dataset singleton.
See :meth:`load_and_cache_dataset` for additional details.
Args:
dataset: a :class:`fiftyone.core.dataset.Dataset`
"""
_cache[dataset.name] = dataset


def change_sample_tags(sample_collection, changes):
"""Applies the changes to tags to all samples of the collection, if
necessary.
Expand Down
43 changes: 30 additions & 13 deletions plugins/panels/model_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,6 @@ def get_confusion_matrices(self, results):
"lc_colorscale": lc_colorscale,
}

def get_mask_targets(self, dataset, gt_field):
mask_targets = dataset.mask_targets.get(gt_field, None)
if mask_targets:
return mask_targets

if dataset.default_mask_targets:
return dataset.default_mask_targets

return None

def load_evaluation(self, ctx):
view_state = ctx.panel.get_state("view") or {}
eval_key = view_state.get("key")
Expand All @@ -353,8 +343,8 @@ def load_evaluation(self, ctx):
mask_targets = None

if evaluation_type == "segmentation":
mask_targets = self.get_mask_targets(ctx.dataset, gt_field)
_init_segmentation_results(results, mask_targets)
mask_targets = _get_mask_targets(ctx.dataset, gt_field)
_init_segmentation_results(ctx.dataset, results, gt_field)

metrics = results.metrics()
per_class_metrics = self.get_per_class_metrics(info, results)
Expand Down Expand Up @@ -591,6 +581,7 @@ def load_view(self, ctx):
)
elif info.config.type == "segmentation":
results = ctx.dataset.load_evaluation_results(eval_key)
_init_segmentation_results(ctx.dataset, results, gt_field)
if results.ytrue_ids is None or results.ypred_ids is None:
# Legacy format segmentations
return
Expand All @@ -600,6 +591,7 @@ def load_view(self, ctx):
gt_field2 = gt_field

results2 = ctx.dataset.load_evaluation_results(eval_key2)
_init_segmentation_results(ctx.dataset, results2, gt_field2)
if results2.ytrue_ids is None or results2.ypred_ids is None:
# Legacy format segmentations
return
Expand Down Expand Up @@ -681,11 +673,35 @@ def render(self, ctx):
)


def _init_segmentation_results(results, mask_targets):
def _get_mask_targets(dataset, gt_field):
mask_targets = dataset.mask_targets.get(gt_field, None)
if mask_targets:
return mask_targets

if dataset.default_mask_targets:
return dataset.default_mask_targets

return None


def _init_segmentation_results(dataset, results, gt_field):
if results.ytrue_ids is None or results.ypred_ids is None:
# Legacy format segmentations
return

if getattr(results, "_classes_map", None):
# Already initialized
return

#
# Ensure the dataset singleton is cached so that subsequent callbacks on
# this panel will use the same `dataset` and hence `results`
#

import fiftyone.server.utils as fosu

fosu.cache_dataset(dataset)

#
# `results.classes` and App callbacks could contain any of the
# following:
Expand All @@ -698,6 +714,7 @@ def _init_segmentation_results(results, mask_targets):
#
classes_map = {c: i for i, c in enumerate(results.classes)}

mask_targets = _get_mask_targets(dataset, gt_field)
if mask_targets is not None:
# `str()` handles cases 1 and 2, and `.get(c, c)` handles case 3
mask_targets = {str(k): v for k, v in mask_targets.items()}
Expand Down

0 comments on commit d5fc685

Please sign in to comment.