Skip to content

Commit

Permalink
Merge pull request #103 from phinate:ignore-nans
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680607724
  • Loading branch information
PIXDev committed Sep 30, 2024
2 parents d60144e + c004366 commit b76b0b6
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 36 deletions.
62 changes: 49 additions & 13 deletions dm_pix/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@
# DO NOT REMOVE - Logging lib.


def mae(a: chex.Array, b: chex.Array) -> chex.Numeric:
def mae(
a: chex.Array,
b: chex.Array,
*,
ignore_nans: bool = False,
) -> chex.Numeric:
"""Returns the Mean Absolute Error between `a` and `b`.
Args:
a: First image (or set of images).
b: Second image (or set of images).
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
MAE between `a` and `b`.
Expand All @@ -41,15 +47,22 @@ def mae(a: chex.Array, b: chex.Array) -> chex.Numeric:
chex.assert_rank([a, b], {3, 4})
chex.assert_type([a, b], float)
chex.assert_equal_shape([a, b])
return jnp.abs(a - b).mean(axis=(-3, -2, -1))
mean_fn = jnp.nanmean if ignore_nans else jnp.mean
return mean_fn(jnp.abs(a - b), axis=(-3, -2, -1))


def mse(a: chex.Array, b: chex.Array) -> chex.Numeric:
def mse(
a: chex.Array,
b: chex.Array,
*,
ignore_nans: bool = False,
) -> chex.Numeric:
"""Returns the Mean Squared Error between `a` and `b`.
Args:
a: First image (or set of images).
b: Second image (or set of images).
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
MSE between `a` and `b`.
Expand All @@ -59,10 +72,16 @@ def mse(a: chex.Array, b: chex.Array) -> chex.Numeric:
chex.assert_rank([a, b], {3, 4})
chex.assert_type([a, b], float)
chex.assert_equal_shape([a, b])
return jnp.square(a - b).mean(axis=(-3, -2, -1))
mean_fn = jnp.nanmean if ignore_nans else jnp.mean
return mean_fn(jnp.square(a - b), axis=(-3, -2, -1))


def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric:
def psnr(
a: chex.Array,
b: chex.Array,
*,
ignore_nans: bool = False,
) -> chex.Numeric:
"""Returns the Peak Signal-to-Noise Ratio between `a` and `b`.
Assumes that the dynamic range of the images (the difference between the
Expand All @@ -71,6 +90,7 @@ def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric:
Args:
a: First image (or set of images).
b: Second image (or set of images).
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
PSNR in decibels between `a` and `b`.
Expand All @@ -80,15 +100,21 @@ def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric:
chex.assert_rank([a, b], {3, 4})
chex.assert_type([a, b], float)
chex.assert_equal_shape([a, b])
return -10.0 * jnp.log(mse(a, b)) / jnp.log(10.0)
return -10.0 * jnp.log(mse(a, b, ignore_nans=ignore_nans)) / jnp.log(10.0)


def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric:
def rmse(
a: chex.Array,
b: chex.Array,
*,
ignore_nans: bool = False,
) -> chex.Array:
"""Returns the Root Mean Squared Error between `a` and `b`.
Args:
a: First image (or set of images).
b: Second image (or set of images).
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
RMSE between `a` and `b`.
Expand All @@ -98,10 +124,15 @@ def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric:
chex.assert_rank([a, b], {3, 4})
chex.assert_type([a, b], float)
chex.assert_equal_shape([a, b])
return jnp.sqrt(mse(a, b))
return jnp.sqrt(mse(a, b, ignore_nans=ignore_nans))


def simse(a: chex.Array, b: chex.Array) -> chex.Numeric:
def simse(
a: chex.Array,
b: chex.Array,
*,
ignore_nans: bool = False,
) -> chex.Numeric:
"""Returns the Scale-Invariant Mean Squared Error between `a` and `b`.
For each image pair, a scaling factor for `b` is computed as the solution to
Expand All @@ -120,6 +151,7 @@ def simse(a: chex.Array, b: chex.Array) -> chex.Numeric:
Args:
a: First image (or set of images).
b: Second image (or set of images).
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
SIMSE between `a` and `b`.
Expand All @@ -130,10 +162,11 @@ def simse(a: chex.Array, b: chex.Array) -> chex.Numeric:
chex.assert_type([a, b], float)
chex.assert_equal_shape([a, b])

a_dot_b = (a * b).sum(axis=(-3, -2, -1), keepdims=True)
b_dot_b = (b * b).sum(axis=(-3, -2, -1), keepdims=True)
sum_fn = jnp.nansum if ignore_nans else jnp.sum
a_dot_b = sum_fn((a * b), axis=(-3, -2, -1), keepdims=True)
b_dot_b = sum_fn((b * b), axis=(-3, -2, -1), keepdims=True)
alpha = a_dot_b / b_dot_b
return mse(a, alpha * b)
return mse(a, alpha * b, ignore_nans=ignore_nans)


def ssim(
Expand All @@ -148,6 +181,7 @@ def ssim(
return_map: bool = False,
precision=jax.lax.Precision.HIGHEST,
filter_fn: Optional[Callable[[chex.Array], chex.Array]] = None,
ignore_nans: bool = False,
) -> chex.Numeric:
"""Computes the structural similarity index (SSIM) between image pairs.
Expand Down Expand Up @@ -176,6 +210,7 @@ def ssim(
filter_fn: An optional argument for overriding the filter function used by
SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
and filter_sigma.
ignore_nans: If True, will ignore NaNs in the inputs.
Returns:
Each image's mean SSIM, or a tensor of individual values if `return_map`.
Expand Down Expand Up @@ -252,5 +287,6 @@ def filter_fn_x(z):
numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
ssim_map = numer / denom
ssim_value = jnp.mean(ssim_map, list(range(-3, 0)))
mean_fn = jnp.nanmean if ignore_nans else jnp.mean
ssim_value = mean_fn(ssim_map, axis=tuple(range(-3, 0)))
return ssim_map if return_map else ssim_value
116 changes: 93 additions & 23 deletions dm_pix/_src/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dm_pix._src.metrics."""

import functools

from absl.testing import absltest
Expand All @@ -32,29 +30,55 @@ def setUp(self):
self._img1 = jax.random.uniform(
key1,
shape=(4, 32, 32, 3),
minval=0.,
maxval=1.,
minval=0.0,
maxval=1.0,
)
self._img2 = jax.random.uniform(
key2,
shape=(4, 32, 32, 3),
minval=0.,
maxval=1.,
minval=0.0,
maxval=1.0,
)

@chex.all_variants
def test_psnr_match(self):
psnr = self.variant(metrics.psnr)
values_jax = psnr(self._img1, self._img2)
values_tf = tf.image.psnr(self._img1, self._img2, max_val=1.).numpy()

values_tf = tf.image.psnr(self._img1, self._img2, max_val=1.0).numpy()

np.testing.assert_allclose(values_jax, values_tf, rtol=1e-3, atol=1e-3)

@chex.all_variants
def test_psnr_ignore_nans(self):
psnr = self.variant(functools.partial(metrics.psnr, ignore_nans=True))

values_jax_nan = psnr(
self._img1.at[:, 0, 0, 0].set(np.nan),
self._img1.at[:, 0, 0, 0].set(np.nan),
)

assert not np.any(np.isnan(values_jax_nan))

@chex.all_variants
def test_simse_invariance(self):
simse = self.variant(metrics.simse)

simse_jax = simse(self._img1, self._img1 * 2.0)

np.testing.assert_allclose(simse_jax, np.zeros(4), rtol=1e-6, atol=1e-6)

@chex.all_variants
def test_simse_ignore_nans(self):
simse = self.variant(functools.partial(metrics.simse, ignore_nans=True))

simse_jax_nan = simse(
self._img1.at[:, 0, 0, 0].set(np.nan),
self._img1.at[:, 0, 0, 0].set(np.nan),
)

assert not np.any(np.isnan(simse_jax_nan))


class SSIMTests(chex.TestCase, absltest.TestCase):

Expand All @@ -65,15 +89,25 @@ def test_ssim_golden(self):
key = jax.random.PRNGKey(0)
for shape in ((2, 12, 12, 3), (12, 12, 3), (2, 12, 15, 3), (17, 12, 3)):
for _ in range(4):
(max_val_key, img0_key, img1_key, filter_size_key, filter_sigma_key,
k1_key, k2_key, key) = jax.random.split(key, 8)
max_val = jax.random.uniform(max_val_key, minval=0.1, maxval=3.)
(
max_val_key,
img0_key,
img1_key,
filter_size_key,
filter_sigma_key,
k1_key,
k2_key,
key,
) = jax.random.split(key, 8)
max_val = jax.random.uniform(max_val_key, minval=0.1, maxval=3.0)
img0 = max_val * jax.random.uniform(img0_key, shape=shape)
img1 = max_val * jax.random.uniform(img1_key, shape=shape)
filter_size = jax.random.randint(
filter_size_key, shape=(), minval=1, maxval=10)
filter_size_key, shape=(), minval=1, maxval=10
)
filter_sigma = jax.random.uniform(
filter_sigma_key, shape=(), minval=0.1, maxval=10.)
filter_sigma_key, shape=(), minval=0.1, maxval=10.0
)
k1 = jax.random.uniform(k1_key, shape=(), minval=0.001, maxval=0.1)
k2 = jax.random.uniform(k2_key, shape=(), minval=0.001, maxval=0.1)

Expand All @@ -84,7 +118,8 @@ def test_ssim_golden(self):
filter_size=filter_size,
filter_sigma=filter_sigma,
k1=k1,
k2=k2).numpy()
k2=k2,
).numpy()
for return_map in [False, True]:
ssim_fn = self.variant(
functools.partial(
Expand All @@ -95,18 +130,19 @@ def test_ssim_golden(self):
k1=k1,
k2=k2,
return_map=return_map,
))
)
)

ssim = ssim_fn(img0, img1)

if not return_map:
np.testing.assert_allclose(ssim, ssim_gt, atol=1e-5, rtol=1e-5)
else:
np.testing.assert_allclose(
np.mean(ssim, list(range(-3, 0))),
ssim_gt,
atol=1e-5,
rtol=1e-5)
self.assertLessEqual(np.max(ssim), 1.)
self.assertGreaterEqual(np.min(ssim), -1.)
np.mean(ssim, list(range(-3, 0))), ssim_gt, atol=1e-5, rtol=1e-5
)
self.assertLessEqual(np.max(ssim), 1.0)
self.assertGreaterEqual(np.min(ssim), -1.0)

@chex.all_variants
def test_ssim_lowerbound(self):
Expand All @@ -118,22 +154,56 @@ def test_ssim_lowerbound(self):
ssim_fn = self.variant(
functools.partial(
metrics.ssim,
max_val=1.,
max_val=1.0,
filter_size=filter_size,
filter_sigma=1.5,
k1=eps,
k2=eps,
))
)
)

ssim = ssim_fn(img, -img)
np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1E-5, rtol=1E-5)

np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1e-5, rtol=1e-5)

@chex.all_variants
def test_ssim_finite_grad(self):
"""Test that SSIM produces a finite gradient on large flat regions."""
img = np.zeros((64, 64, 3))

grad = self.variant(jax.grad(metrics.ssim))(img, img)

np.testing.assert_equal(grad, np.zeros_like(grad))

@chex.all_variants
def test_ssim_ignore_nans(self):
"""Test that SSIM ignores NaNs."""
ssim_fn = self.variant(
functools.partial(
metrics.ssim,
max_val=1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
ignore_nans=True,
)
)
key = jax.random.PRNGKey(0)
_, key1 = jax.random.split(key)
img = jax.random.uniform(
key1,
shape=(4, 32, 32, 3),
minval=0.0,
maxval=1.0,
)

ssim = ssim_fn(
img.at[:, 0, 0, 0].set(np.nan), img.at[:, 0, 0, 0].set(np.nan)
)

assert not np.any(np.isnan(ssim))


if __name__ == "__main__":
absltest.main()

0 comments on commit b76b0b6

Please sign in to comment.