Skip to content

Commit

Permalink
Simplified ScipyDistances
Browse files Browse the repository at this point in the history
Also enables kwargs to be passed to scipy's cdist
  • Loading branch information
javiber committed Jan 3, 2023
1 parent c316914 commit 9d753db
Showing 1 changed file with 6 additions and 29 deletions.
35 changes: 6 additions & 29 deletions norfair/distances.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Predefined distances"""
from abc import ABC, abstractmethod
from functools import partial
from logging import warning
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

Expand Down Expand Up @@ -39,7 +40,6 @@ def get_distances(
np.ndarray
A matrix containing the distances between objects and candidates.
"""
pass


class ScalarDistance(Distance):
Expand Down Expand Up @@ -219,37 +219,16 @@ class ScipyDistance(VectorizedDistance):
Defines the specific Scipy metric to use to calculate the pairwise distances between
new candidates and objects.
Other kwargs are passed through to cdist
See Also
--------
[`scipy.spatial.distance.cdist`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
"""

def __init__(
self,
metric: str = "euclidean",
):
def __init__(self, metric: str = "euclidean", **kwargs):
self.metric = metric

def _compute_distance(
self, stacked_candidates: np.ndarray, stacked_objects: np.ndarray
) -> np.ndarray:
"""
Method that computes the pairwise distances between new candidates and objects.
It is intended to use the entire vectors to compare to each other in a single operation.
Parameters
----------
stacked_candidates : np.ndarray
np.ndarray containing a stack of candidates to be compared with the stacked_objects.
stacked_objects : np.ndarray
np.ndarray containing a stack of objects to be compared with the stacked_objects.
Returns
-------
np.ndarray
A matrix containing the distances between objects and candidates.
"""
return cdist(stacked_candidates, stacked_objects, metric=self.metric)
super().__init__(distance_function=partial(cdist, metric=self.metric, **kwargs))


def frobenius(detection: "Detection", tracked_object: "TrackedObject") -> float:
Expand Down Expand Up @@ -363,9 +342,7 @@ def _validate_bboxes(bboxes: np.ndarray):
), f"Bounding boxes must be defined as np.array with (N, 4) shape, {bboxes} given"

if not (all(bboxes[:, 0] < bboxes[:, 2]) and all(bboxes[:, 1] < bboxes[:, 3])):
warning(
f"Incorrect bounding boxes. Check that x_min < x_max and y_min < y_max."
)
warning("Incorrect bounding boxes. Check that x_min < x_max and y_min < y_max.")


def iou(candidates: np.ndarray, objects: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit 9d753db

Please sign in to comment.