Skip to content

Commit

Permalink
[Fix] Fix psnr,snr,ssim,mae and mse fail to compute on videos (#89)
Browse files Browse the repository at this point in the history
* [Fix] fix psnr,snr,ssim,mae and mse fail to compute on videos

* fix snr

* [Fix] fix psnr,snr,ssim,mae and mse fail to compute on videos

* fix snr

* fix comments

* add ut
  • Loading branch information
Z-Fran authored Feb 13, 2023
1 parent cf34dbc commit 5a3647c
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 25 deletions.
17 changes: 14 additions & 3 deletions mmeval/metrics/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,21 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr
f'Image shapes are different: \
{groundtruth.shape}, {prediction.shape}.')
if masks is None:
self._results.append(self.compute_mae(prediction, groundtruth))
result = self.compute_mae(prediction, groundtruth)
else:
self._results.append(
self.compute_mae(prediction, groundtruth, masks[i]))
# when prediction is a image
if len(prediction.shape) <= 3:
result = self.compute_mae(prediction, groundtruth,
masks[i])
# when prediction is a video
else:
result_sum = 0
for j in range(prediction.shape[0]):
result_sum += self.compute_mae(prediction[j],
groundtruth[j],
masks[i][j])
result = result_sum / prediction.shape[0]
self._results.append(result)

def compute_metric(self, results: List[np.float32]) -> Dict[str, float]:
"""Compute the MeanAbsoluteError metric.
Expand Down
17 changes: 14 additions & 3 deletions mmeval/metrics/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,21 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr
f'Image shapes are different: \
{groundtruth.shape}, {prediction.shape}.')
if masks is None:
self._results.append(self.compute_mse(prediction, groundtruth))
result = self.compute_mse(prediction, groundtruth)
else:
self._results.append(
self.compute_mse(prediction, groundtruth, masks[i]))
# when prediction is a image
if len(prediction.shape) <= 3:
result = self.compute_mse(prediction, groundtruth,
masks[i])
# when prediction is a video
else:
result_sum = 0
for j in range(prediction.shape[0]):
result_sum += self.compute_mse(prediction[j],
groundtruth[j],
masks[i][j])
result = result_sum / prediction.shape[0]
self._results.append(result)

def compute_metric(self, results: List[np.float32]) -> Dict[str, float]:
"""Compute the MeanSquaredError metric.
Expand Down
15 changes: 11 additions & 4 deletions mmeval/metrics/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,22 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr
crop_border=self.crop_border,
input_order=self.input_order,
convert_to=self.convert_to,
channel_order=self.channel_order)
channel_order=channel_order)
prediction = reorder_and_crop(
prediction,
crop_border=self.crop_border,
input_order=self.input_order,
convert_to=self.convert_to,
channel_order=self.channel_order)

self._results.append(self.compute_psnr(prediction, groundtruth))
channel_order=channel_order)

if len(prediction.shape) == 3:
prediction = np.expand_dims(prediction, axis=0)
groundtruth = np.expand_dims(groundtruth, axis=0)
_psnr_score = []
for i in range(prediction.shape[0]):
_psnr_score.append(
self.compute_psnr(prediction[i], groundtruth[i]))
self._results.append(np.array(_psnr_score).mean())

def compute_metric(self, results: List[np.float64]) -> Dict[str, float]:
"""Compute the PeakSignalNoiseRatio metric.
Expand Down
30 changes: 19 additions & 11 deletions mmeval/metrics/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,31 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr
"""
channel_order = self.channel_order \
if channel_order is None else channel_order
for pred, gt in zip(predictions, groundtruths):
assert gt.shape == pred.shape, (
f'Image shapes are different: {gt.shape}, {pred.shape}.')
gt = reorder_and_crop(
gt,
for prediction, groundtruth in zip(predictions, groundtruths):
assert groundtruth.shape == prediction.shape, (
f'Image shapes are different: \
{groundtruth.shape}, {prediction.shape}.')
groundtruth = reorder_and_crop(
groundtruth,
crop_border=self.crop_border,
input_order=self.input_order,
convert_to=self.convert_to,
channel_order=self.channel_order)
pred = reorder_and_crop(
pred,
channel_order=channel_order)
prediction = reorder_and_crop(
prediction,
crop_border=self.crop_border,
input_order=self.input_order,
convert_to=self.convert_to,
channel_order=self.channel_order)

self._results.append(self.compute_snr(pred, gt))
channel_order=channel_order)

if len(prediction.shape) == 3:
prediction = np.expand_dims(prediction, axis=0)
groundtruth = np.expand_dims(groundtruth, axis=0)
_snr_score = []
for i in range(prediction.shape[0]):
_snr_score.append(
self.compute_snr(prediction[i], groundtruth[i]))
self._results.append(np.array(_snr_score).mean())

def compute_metric(self, results: List[np.float64]) -> Dict[str, float]:
"""Compute the SignalNoiseRatio metric.
Expand Down
9 changes: 7 additions & 2 deletions mmeval/metrics/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,14 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr
convert_to=self.convert_to,
channel_order=channel_order)

if len(pred.shape) == 3:
pred = np.expand_dims(pred, axis=0)
gt = np.expand_dims(gt, axis=0)
_ssim_score = []
for i in range(pred.shape[2]):
_ssim_score.append(self.compute_ssim(pred[..., i], gt[..., i]))
for i in range(pred.shape[0]):
for j in range(pred.shape[3]):
_ssim_score.append(
self.compute_ssim(pred[i][..., j], gt[i][..., j]))
self._results.append(np.array(_ssim_score).mean())

def compute_metric(self, results: List[np.float64]) -> Dict[str, float]:
Expand Down
8 changes: 8 additions & 0 deletions mmeval/metrics/utils/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ def reorder_and_crop(img: np.ndarray,
np.array: The transformation results.
"""

if len(img.shape) == 4:
result = []
for i in range(img.shape[0]):
result.append(
reorder_and_crop(img[i], crop_border, input_order, convert_to,
channel_order))
return np.array(result).astype(np.float64)

img = reorder_image(img, input_order=input_order)
img = img.astype(np.float32)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_metrics/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def test_mae():
# test image input
preds = [np.ones((32, 32, 3))]
gts = [np.ones((32, 32, 3)) * 2]
mask = np.ones((32, 32, 3)) * 2
Expand All @@ -19,3 +20,19 @@ def test_mae():
mae_results = mae(preds, gts, [mask])
assert isinstance(mae_results, dict)
np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627)

# test video input
preds = [np.ones((5, 32, 32, 3))]
gts = [np.ones((5, 32, 32, 3)) * 2]
mask = np.ones((5, 32, 32, 3)) * 2
mask[:, :16] *= 0

mae = MeanAbsoluteError()
mae_results = mae(preds, gts)
assert isinstance(mae_results, dict)
np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627)

mae = MeanAbsoluteError()
mae_results = mae(preds, gts, [mask])
assert isinstance(mae_results, dict)
np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627)
17 changes: 17 additions & 0 deletions tests/test_metrics/test_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


def test_mse():
# test image input
preds = [np.ones((32, 32, 3))]
gts = [np.ones((32, 32, 3)) * 2]
mask = np.ones((32, 32, 3)) * 2
Expand All @@ -19,3 +20,19 @@ def test_mse():
mse_results = mse(preds, gts)
assert isinstance(mse_results, dict)
np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496)

# test video input
preds = [np.ones((5, 32, 32, 3))]
gts = [np.ones((5, 32, 32, 3)) * 2]
mask = np.ones((5, 32, 32, 3)) * 2
mask[:, :16] *= 0

mse = MeanSquaredError()
mse_results = mse(preds, gts)
assert isinstance(mse_results, dict)
np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496)

mse = MeanSquaredError()
mse_results = mse(preds, gts)
assert isinstance(mse_results, dict)
np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496)
4 changes: 3 additions & 1 deletion tests/test_metrics/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def test_psnr_init():
}, [np.ones((32, 32, 3))], [np.ones(
(32, 32, 3)) * 2], 49.45272242415597),
({}, [np.ones((32, 32))], [np.ones((32, 32))], float('inf')),
({}, [np.zeros((32, 32))], [np.ones((32, 32)) * 255], 0)])
({}, [np.zeros((32, 32))], [np.ones((32, 32)) * 255], 0),
({}, [np.ones((5, 3, 32, 32))], [np.ones(
(5, 3, 32, 32)) * 2], 48.1308036086791)])
def test_psnr(metric_kwargs, img1, img2, results):
psnr = PeakSignalNoiseRatio(**metric_kwargs)
psnr_results = psnr(img1, img2)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_metrics/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def test_snr_init():
({'input_order': 'HWC', 'convert_to': 'Y'}, [np.ones((32, 32, 3))],
[np.ones((32, 32, 3)) * 2], 26.290039980499536),
({}, [np.ones((32, 32))], [np.ones((32, 32))], float('inf')),
({}, [np.zeros((32, 32))], [np.ones((32, 32))], 0)
({}, [np.zeros((32, 32))], [np.ones((32, 32))], 0),
({}, [np.ones((5, 3, 32, 32))], [np.ones((5, 3, 32, 32)) * 2],
6.020599913279624)
]
)
def test_snr(metric_kwargs, img1, img2, results):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_metrics/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_ssim_init():
[np.ones((3, 32, 32)) * 2], 0.9130623),
({'convert_to': 'Y', 'input_order': 'HWC'}, [np.ones((32, 32, 3))],
[np.ones((32, 32, 3)) * 2], 0.9987801),
({}, [np.ones((5, 3, 32, 32))], [np.ones((5, 3, 32, 32)) * 2],
0.9130623)
]
)
def test_ssim(metric_kwargs, img1, img2, results):
Expand Down

0 comments on commit 5a3647c

Please sign in to comment.