Fix predictor
gitttt-1234 committed May 29, 2024
1 parent 3288030 commit 4655651
Showing 2 changed files with 46 additions and 10 deletions.
16 changes: 8 additions & 8 deletions docs/config.md
@@ -121,14 +121,14 @@ The config file has four main sections:
convolutions may be able to learn richer or more complex upsampling to
recover details from higher scales. Default: True.
- `head_configs`: (List[dict]) List of heads in the model. For example, the BottomUp model has both 'MultiInstanceConfmapsHead' and 'PartAffinityFieldsHead' heads. (A sketch of one such list is shown after this file's diff.)
- `head_type`: (str) Name of the head. Supported values are 'SingleInstanceConfmapsHead', 'CentroidConfmapsHead', 'CenteredInstanceConfmapsHead', 'MultiInstanceConfmapsHead', 'PartAffinityFieldsHead', 'ClassMapsHead', 'ClassVectorsHead', 'OffsetRefinementHead'
- `head_config`:
- `part_names`: (List[str]) Text name of the body parts (nodes) that the head will be configured to produce. The number of parts determines the number of channels in the output. If not specified, all body parts in the skeleton will be used. This config does not apply for 'PartAffinityFieldsHead'.
- `edges`: (List[str]) **Note**: Only for 'PartAffinityFieldsHead'. List of indices `(src, dest)` that form an edge.
- `anchor_part`: (int) **Note**: Only for 'CenteredInstanceConfmapsHead'. Index of the anchor node to use as the anchor point. If None, the midpoint of the bounding box of all visible instance points will be used as the anchor. The bounding box midpoint will also be used if the anchor part is specified but not visible in the instance. Setting a reliable anchor point can significantly improve topdown model accuracy as they benefit from a consistent geometry of the body parts relative to the center of the image.
- `sigma`: (float) Spread of the Gaussian distribution of the confidence maps as a scalar float. Smaller values are more precise but may be difficult to learn as they have a lower density within the image space. Larger values are easier to learn but are less precise with respect to the peak coordinate. This spread is in units of pixels of the model input image, i.e., the image resolution after any input scaling is applied.
- `output_stride`: (float) The stride of the output confidence maps relative to the input image. This is the reciprocal of the resolution, e.g., an output stride of 2 results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution.
- `loss_weight`: (float) Scalar float used to weigh the loss term for this head during training. Increase this to encourage the optimization to focus on improving this specific output in multi-head models.

- `trainer_config`:
- `train_data_loader`:
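To make the `head_configs` entries documented above concrete, here is a minimal sketch of a bottom-up head list built with `OmegaConf` (which the predictor code below also uses). The part names, edges, and numeric values are illustrative assumptions, not values from a real training config, and the exact nesting may differ from what the training pipeline expects.

```python
# Illustrative sketch only: field names follow the documentation above, but the
# concrete part names, edges, and values are assumptions, not a real config.
from omegaconf import OmegaConf

head_configs = OmegaConf.create(
    [
        {
            "head_type": "MultiInstanceConfmapsHead",
            "head_config": {
                "part_names": ["head", "thorax", "abdomen"],  # nodes to predict
                "sigma": 2.5,        # spread of confidence-map Gaussians (input-image pixels)
                "output_stride": 2,  # confmaps at 0.5x the input resolution
                "loss_weight": 1.0,
            },
        },
        {
            "head_type": "PartAffinityFieldsHead",
            "head_config": {
                "edges": [["head", "thorax"], ["thorax", "abdomen"]],  # (src, dest) pairs
                "sigma": 15.0,
                "output_stride": 4,
                "loss_weight": 1.0,
            },
        },
    ]
)

print(head_configs[1].head_type)  # -> "PartAffinityFieldsHead"
```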
40 changes: 38 additions & 2 deletions sleap_nn/inference/predictors.py
@@ -241,11 +241,16 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
yield output

except Exception as e:
raise Exception(f"Error in VideoReader: {e}")
raise Exception(f"Error in VideoReader during data processing: {e}")

finally:
self.pipeline.join()

def predict(
self,
make_labels: bool = True,
save_path: str = None,
) -> Union[List[Dict[str, np.ndarray]], sio.Labels]:
"""Run inference on a data source.
Args:
@@ -259,6 +264,20 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
otherwise a list of dictionaries containing batches of numpy arrays with the
raw results.
"""
# Initialize inference loop generator.
generator = self._predict_generator()

if make_labels:
# Create SLEAP data structures from the predictions.
pred_labels = self._make_labeled_frames_from_generator(generator)
if save_path:
sio.io.slp.write_labels(save_path, pred_labels)
return pred_labels

else:
# Just return the raw results.
return list(generator)

@abstractmethod
def _make_labeled_frames_from_generator(self, generator) -> sio.Labels:
"""Create `sio.Labels` object from the predictions."""
@@ -304,8 +323,12 @@ def _initialize_inference_model(self):
self.centroid_config.inference_config.data["skeletons"] = (
self.centroid_config.data_config.skeletons
)

# If both centroid and centered-instance models are provided, set return_crops to True
if self.confmap_model:
return_crops = True

# initialize centroid crop layer
centroid_crop_layer = CentroidCrop(
torch_model=self.centroid_model,
peak_threshold=self.centroid_config.inference_config.peak_threshold,
@@ -345,7 +368,7 @@ def _initialize_inference_model(self):
max_stride=max_stride,
)

# Initialize the inference model with centroid and conf map layers
# Initialize the inference model with centroid and instance peak layers
self.inference_model = TopDownInferenceModel(
centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
)
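As a rough mental model of what `TopDownInferenceModel` composes here: the centroid stage proposes one crop per detected animal, and the instance-peak stage localizes body parts within each crop. The sketch below is purely conceptual and does not use the library's actual layer APIs.

```python
# Conceptual sketch of the two-stage top-down flow (not the real layer API):
# stage 1 finds animal centroids and returns crops, stage 2 finds per-part
# peaks inside each crop.
from typing import Callable, Dict, List

import numpy as np


def topdown_forward(
    image: np.ndarray,
    centroid_stage: Callable[[np.ndarray], List[np.ndarray]],  # image -> list of crops
    instance_peaks_stage: Callable[[np.ndarray], Dict[str, np.ndarray]],  # crop -> peaks
) -> List[Dict[str, np.ndarray]]:
    crops = centroid_stage(image)  # one crop per detected centroid
    return [instance_peaks_stage(crop) for crop in crops]
```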
@@ -430,9 +453,13 @@ class (doesn't return a pipeline) and the Thread is started in
provider is LabelsReader.
"""
self.provider = self.data_config.provider

# LabelsReader provider
if self.provider == "LabelsReader":
provider = LabelsReader
instances_key = True

# No need for the `instances` key when both centroid and centered-instance models are provided
if self.centroid_config and self.confmap_config:
instances_key = False

@@ -441,6 +468,7 @@ class (doesn't return a pipeline) and the Thread is started in
)

self.videos = data_provider.labels.videos

pipeline = Normalizer(data_provider, is_rgb=self.data_config.is_rgb)
pipeline = SizeMatcher(
pipeline,
@@ -486,6 +514,7 @@ class (doesn't return a pipeline) and the Thread is started in

return self.pipeline

# VideoReader provider
elif self.provider == "VideoReader":
if self.centroid_config is None:
raise ValueError(
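Stepping back from the truncated hunks above: the provider handling boils down to a small dispatch. With a `LabelsReader`, the ground-truth `instances` key is only needed when there is no full centroid + centered-instance pair to generate crops, while a `VideoReader` requires a centroid model. The standalone sketch below mirrors that logic but is not a drop-in replacement for `make_pipeline`, and its error message is a stand-in since the original is cut off in the diff.

```python
# Standalone sketch of the provider dispatch above; mirrors the diff's logic
# but is not the library's actual make_pipeline() implementation.
def needs_instances_key(provider: str, has_centroid_model: bool, has_confmap_model: bool) -> bool:
    """Return whether the data pipeline needs the ground-truth `instances` key."""
    if provider == "VideoReader":
        if not has_centroid_model:
            # A VideoReader has no labels, so a centroid model must supply crops.
            raise ValueError("VideoReader requires a centroid model.")  # message is a stand-in
        return False  # assumption: raw frames carry no ground-truth instances
    # LabelsReader: ground-truth instances are only needed when there is no full
    # top-down (centroid + centered-instance) pair to generate crops itself.
    return not (has_centroid_model and has_confmap_model)
```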
@@ -894,14 +923,19 @@ class BottomUpPredictor(Predictor):

def _initialize_inference_model(self):
"""Initialize the inference model from the trained models and configuration."""
# Get the index of the PAFs head config
paf_idx = [
x.head_type == "PartAffinityFieldsHead"
for x in self.bottomup_config.model_config.head_configs
].index(True)

# Get the index of the confmaps head config
confmaps_idx = [
x.head_type == "MultiInstanceConfmapsHead"
for x in self.bottomup_config.model_config.head_configs
].index(True)

# initialize the paf scorer
paf_scorer = PAFScorer.from_config(
config=OmegaConf.create(
{
@@ -940,6 +974,8 @@ def _initialize_inference_model(self):
else self.min_line_scores
),
)

# initialize the BottomUpInferenceModel
self.inference_model = BottomUpInferenceModel(
torch_model=self.bottomup_model,
paf_scorer=paf_scorer,
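Finally, the head-index lookup added in `BottomUpPredictor._initialize_inference_model` relies on a small idiom: build a boolean list over the head configs and take the index of the first `True`. The self-contained illustration below uses stand-in `SimpleNamespace` objects rather than the real model config.

```python
# Stand-in head configs with just a `head_type` attribute, to illustrate the
# boolean-list + .index(True) lookup used in BottomUpPredictor above.
from types import SimpleNamespace

head_configs = [
    SimpleNamespace(head_type="MultiInstanceConfmapsHead"),
    SimpleNamespace(head_type="PartAffinityFieldsHead"),
]

paf_idx = [h.head_type == "PartAffinityFieldsHead" for h in head_configs].index(True)
confmaps_idx = [h.head_type == "MultiInstanceConfmapsHead" for h in head_configs].index(True)

assert (confmaps_idx, paf_idx) == (0, 1)
# Note: .index(True) raises ValueError if the requested head type is missing entirely.
```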
