Skip to content

Commit

Permalink
feat(w2): add W2 distance method for images
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Nov 4, 2021
1 parent c508a57 commit 552010d
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions color_matcher/mvgd_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,28 @@ def w2_dist(mu_a: np.ndarray, mu_b: np.ndarray, cov_a: np.ndarray, cov_b: np.nda

return float(mean_dist + vars_dist)

def w2_img_dist(self, img_a: np.ndarray, img_b:np.ndarray):

"""
Wasserstein-2 image distance metric is a similarity measure for Gaussian distributions
:param img_a: Image array *a*
:param img_b: Image array *b*
:type img_a: :class:`~numpy:numpy.ndarray`
:type img_b: :class:`~numpy:numpy.ndarray`
:return: **scalar**: Wasserstein-2 image metric as a scalar
:rtype: float
"""

img_src, img_ref = img_a, img_b
mu_a, mu_b = np.mean(img_src, axis=(0, 1)), np.mean(img_ref, axis=(0, 1))
cov_a, cov_b = np.cov(img_src.reshape(-1, 3).T), np.cov(img_ref.reshape(-1, 3).T)
w2_img_dist = self.w2_dist(mu_a, mu_b, cov_a, cov_b)

return w2_img_dist

def check_dims(self):
"""
Catch error for wrong color channel number (e.g., gray scale image)
Expand Down

0 comments on commit 552010d

Please sign in to comment.