From 4655651bcbc102c54f348aab4db1d55347e997af Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 29 May 2024 11:56:01 -0700 Subject: [PATCH] Fix predictor --- docs/config.md | 16 ++++++------- sleap_nn/inference/predictors.py | 40 ++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/docs/config.md b/docs/config.md index 67d5c510..c3a775a5 100644 --- a/docs/config.md +++ b/docs/config.md @@ -121,14 +121,14 @@ The config file has four main sections: convolutions may be able to learn richer or more complex upsampling to recover details from higher scales. Default: True. - `head_configs`: (List[dict]) List of heads in the model. For eg, BottomUp model has both 'MultiInstanceConfmapsHead' and 'PartAffinityFieldsHead' heads. - - `head_type`: (str) Name of the head. Supported values are 'SingleInstanceConfmapsHead', 'CentroidConfmapsHead', 'CenteredInstanceConfmapsHead', 'MultiInstanceConfmapsHead', 'PartAffinityFieldsHead', 'ClassMapsHead', 'ClassVectorsHead', 'OffsetRefinementHead' - - `head_config`: - - `part_names`: (List[str]) Text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'. - - `edges`: (List[str]) **Note**: Only for 'PartAffinityFieldsHead'. List of indices `(src, dest)` that form an edge. - - `anchor_part`: (int) **Note**: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image. - - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. - - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. - - `loss_weight`: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models. + - `head_type`: (str) Name of the head. Supported values are 'SingleInstanceConfmapsHead', 'CentroidConfmapsHead', 'CenteredInstanceConfmapsHead', 'MultiInstanceConfmapsHead', 'PartAffinityFieldsHead', 'ClassMapsHead', 'ClassVectorsHead', 'OffsetRefinementHead' + - `head_config`: + - `part_names`: (List[str]) Text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'. + - `edges`: (List[str]) **Note**: Only for 'PartAffinityFieldsHead'. List of indices `(src, dest)` that form an edge. + - `anchor_part`: (int) **Note**: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image. + - `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied. + - `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + - `loss_weight`: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models. - `trainer_config`: - `train_data_loader`: diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 984edb35..b6de5937 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -241,11 +241,16 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: yield output except Exception as e: - raise Exception(f"Error in VideoReader: {e}") + raise Exception(f"Error in VideoReader during data processing: {e}") finally: self.pipeline.join() + def predict( + self, + make_labels: bool = True, + save_path: str = None, + ) -> Union[List[Dict[str, np.ndarray]], sio.Labels]: """Run inference on a data source. Args: @@ -259,6 +264,20 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: otherwise a list of dictionaries containing batches of numpy arrays with the raw results. """ + # Initialize inference loop generator. + generator = self._predict_generator() + + if make_labels: + # Create SLEAP data structures from the predictions. + pred_labels = self._make_labeled_frames_from_generator(generator) + if save_path: + sio.io.slp.write_labels(save_path, pred_labels) + return pred_labels + + else: + # Just return the raw results. + return list(generator) + @abstractmethod def _make_labeled_frames_from_generator(self, generator) -> sio.Labels: """Create `sio.Labels` object from the predictions.""" @@ -304,8 +323,12 @@ def _initialize_inference_model(self): self.centroid_config.inference_config.data["skeletons"] = ( self.centroid_config.data_config.skeletons ) + + # if both centroid and centered-instance model are provided, set return crops to True if self.confmap_model: return_crops = True + + # initialize centroid crop layer centroid_crop_layer = CentroidCrop( torch_model=self.centroid_model, peak_threshold=self.centroid_config.inference_config.peak_threshold, @@ -345,7 +368,7 @@ def _initialize_inference_model(self): max_stride=max_stride, ) - # Initialize the inference model with centroid and conf map layers + # Initialize the inference model with centroid and instance peak layers self.inference_model = TopDownInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer ) @@ -430,9 +453,13 @@ class (doesn't return a pipeline) and the Thread is started in provider is LabelsReader. """ self.provider = self.data_config.provider + + # LabelsReader provider if self.provider == "LabelsReader": provider = LabelsReader instances_key = True + + # no need of `instances` key for Centered-instance model if self.centroid_config and self.confmap_config: instances_key = False @@ -441,6 +468,7 @@ class (doesn't return a pipeline) and the Thread is started in ) self.videos = data_provider.labels.videos + pipeline = Normalizer(data_provider, is_rgb=self.data_config.is_rgb) pipeline = SizeMatcher( pipeline, @@ -486,6 +514,7 @@ class (doesn't return a pipeline) and the Thread is started in return self.pipeline + # VideoReader provider elif self.provider == "VideoReader": if self.centroid_config is None: raise ValueError( @@ -894,14 +923,19 @@ class BottomUpPredictor(Predictor): def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" + # get the index of pafs head configs paf_idx = [ x.head_type == "PartAffinityFieldsHead" for x in self.bottomup_config.model_config.head_configs ].index(True) + + # get the index of confmap head configs confmaps_idx = [ x.head_type == "MultiInstanceConfmapsHead" for x in self.bottomup_config.model_config.head_configs ].index(True) + + # initialize the paf scorer paf_scorer = PAFScorer.from_config( config=OmegaConf.create( { @@ -940,6 +974,8 @@ def _initialize_inference_model(self): else self.min_line_scores ), ) + + # initialize the BottomUpInferenceModel self.inference_model = BottomUpInferenceModel( torch_model=self.bottomup_model, paf_scorer=paf_scorer,