diff --git a/fiftyone/server/utils.py b/fiftyone/server/utils.py index eeaa7dc260..264394470d 100644 --- a/fiftyone/server/utils.py +++ b/fiftyone/server/utils.py @@ -49,6 +49,21 @@ def load_and_cache_dataset(name): return dataset +def cache_dataset(dataset): + """Caches the given dataset. + + This method ensures that subsequent calls to + :func:`fiftyone.core.dataset.load_dataset` in async calls will return this + dataset singleton. + + See :meth:`load_and_cache_dataset` for additional details. + + Args: + dataset: a :class:`fiftyone.core.dataset.Dataset` + """ + _cache[dataset.name] = dataset + + def change_sample_tags(sample_collection, changes): """Applies the changes to tags to all samples of the collection, if necessary. diff --git a/plugins/panels/model_evaluation.py b/plugins/panels/model_evaluation.py index b1bd95142a..a36592f204 100644 --- a/plugins/panels/model_evaluation.py +++ b/plugins/panels/model_evaluation.py @@ -317,16 +317,6 @@ def get_confusion_matrices(self, results): "lc_colorscale": lc_colorscale, } - def get_mask_targets(self, dataset, gt_field): - mask_targets = dataset.mask_targets.get(gt_field, None) - if mask_targets: - return mask_targets - - if dataset.default_mask_targets: - return dataset.default_mask_targets - - return None - def load_evaluation(self, ctx): view_state = ctx.panel.get_state("view") or {} eval_key = view_state.get("key") @@ -353,8 +343,8 @@ def load_evaluation(self, ctx): mask_targets = None if evaluation_type == "segmentation": - mask_targets = self.get_mask_targets(ctx.dataset, gt_field) - _init_segmentation_results(results, mask_targets) + mask_targets = _get_mask_targets(ctx.dataset, gt_field) + _init_segmentation_results(ctx.dataset, results, gt_field) metrics = results.metrics() per_class_metrics = self.get_per_class_metrics(info, results) @@ -591,6 +581,7 @@ def load_view(self, ctx): ) elif info.config.type == "segmentation": results = ctx.dataset.load_evaluation_results(eval_key) + _init_segmentation_results(ctx.dataset, results, gt_field) if results.ytrue_ids is None or results.ypred_ids is None: # Legacy format segmentations return @@ -600,6 +591,7 @@ def load_view(self, ctx): gt_field2 = gt_field results2 = ctx.dataset.load_evaluation_results(eval_key2) + _init_segmentation_results(ctx.dataset, results2, gt_field2) if results2.ytrue_ids is None or results2.ypred_ids is None: # Legacy format segmentations return @@ -681,11 +673,35 @@ def render(self, ctx): ) -def _init_segmentation_results(results, mask_targets): +def _get_mask_targets(dataset, gt_field): + mask_targets = dataset.mask_targets.get(gt_field, None) + if mask_targets: + return mask_targets + + if dataset.default_mask_targets: + return dataset.default_mask_targets + + return None + + +def _init_segmentation_results(dataset, results, gt_field): if results.ytrue_ids is None or results.ypred_ids is None: # Legacy format segmentations return + if getattr(results, "_classes_map", None): + # Already initialized + return + + # + # Ensure the dataset singleton is cached so that subsequent callbacks on + # this panel will use the same `dataset` and hence `results` + # + + import fiftyone.server.utils as fosu + + fosu.cache_dataset(dataset) + # # `results.classes` and App callbacks could contain any of the # following: @@ -698,6 +714,7 @@ def _init_segmentation_results(results, mask_targets): # classes_map = {c: i for i, c in enumerate(results.classes)} + mask_targets = _get_mask_targets(dataset, gt_field) if mask_targets is not None: # `str()` handles cases 1 and 2, and `.get(c, c)` handles case 3 mask_targets = {str(k): v for k, v in mask_targets.items()}