diff --git a/mmeval/metrics/mae.py b/mmeval/metrics/mae.py index 1acb7cba..ac6ef278 100644 --- a/mmeval/metrics/mae.py +++ b/mmeval/metrics/mae.py @@ -53,10 +53,21 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr f'Image shapes are different: \ {groundtruth.shape}, {prediction.shape}.') if masks is None: - self._results.append(self.compute_mae(prediction, groundtruth)) + result = self.compute_mae(prediction, groundtruth) else: - self._results.append( - self.compute_mae(prediction, groundtruth, masks[i])) + # when prediction is a image + if len(prediction.shape) <= 3: + result = self.compute_mae(prediction, groundtruth, + masks[i]) + # when prediction is a video + else: + result_sum = 0 + for j in range(prediction.shape[0]): + result_sum += self.compute_mae(prediction[j], + groundtruth[j], + masks[i][j]) + result = result_sum / prediction.shape[0] + self._results.append(result) def compute_metric(self, results: List[np.float32]) -> Dict[str, float]: """Compute the MeanAbsoluteError metric. diff --git a/mmeval/metrics/mse.py b/mmeval/metrics/mse.py index f2a3ab57..eab25309 100644 --- a/mmeval/metrics/mse.py +++ b/mmeval/metrics/mse.py @@ -52,10 +52,21 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr f'Image shapes are different: \ {groundtruth.shape}, {prediction.shape}.') if masks is None: - self._results.append(self.compute_mse(prediction, groundtruth)) + result = self.compute_mse(prediction, groundtruth) else: - self._results.append( - self.compute_mse(prediction, groundtruth, masks[i])) + # when prediction is a image + if len(prediction.shape) <= 3: + result = self.compute_mse(prediction, groundtruth, + masks[i]) + # when prediction is a video + else: + result_sum = 0 + for j in range(prediction.shape[0]): + result_sum += self.compute_mse(prediction[j], + groundtruth[j], + masks[i][j]) + result = result_sum / prediction.shape[0] + self._results.append(result) def compute_metric(self, results: List[np.float32]) -> Dict[str, float]: """Compute the MeanSquaredError metric. diff --git a/mmeval/metrics/psnr.py b/mmeval/metrics/psnr.py index a4476d77..c462c72f 100644 --- a/mmeval/metrics/psnr.py +++ b/mmeval/metrics/psnr.py @@ -92,15 +92,22 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr crop_border=self.crop_border, input_order=self.input_order, convert_to=self.convert_to, - channel_order=self.channel_order) + channel_order=channel_order) prediction = reorder_and_crop( prediction, crop_border=self.crop_border, input_order=self.input_order, convert_to=self.convert_to, - channel_order=self.channel_order) - - self._results.append(self.compute_psnr(prediction, groundtruth)) + channel_order=channel_order) + + if len(prediction.shape) == 3: + prediction = np.expand_dims(prediction, axis=0) + groundtruth = np.expand_dims(groundtruth, axis=0) + _psnr_score = [] + for i in range(prediction.shape[0]): + _psnr_score.append( + self.compute_psnr(prediction[i], groundtruth[i])) + self._results.append(np.array(_psnr_score).mean()) def compute_metric(self, results: List[np.float64]) -> Dict[str, float]: """Compute the PeakSignalNoiseRatio metric. diff --git a/mmeval/metrics/snr.py b/mmeval/metrics/snr.py index 080be62e..beffbe1d 100644 --- a/mmeval/metrics/snr.py +++ b/mmeval/metrics/snr.py @@ -82,23 +82,31 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr """ channel_order = self.channel_order \ if channel_order is None else channel_order - for pred, gt in zip(predictions, groundtruths): - assert gt.shape == pred.shape, ( - f'Image shapes are different: {gt.shape}, {pred.shape}.') - gt = reorder_and_crop( - gt, + for prediction, groundtruth in zip(predictions, groundtruths): + assert groundtruth.shape == prediction.shape, ( + f'Image shapes are different: \ + {groundtruth.shape}, {prediction.shape}.') + groundtruth = reorder_and_crop( + groundtruth, crop_border=self.crop_border, input_order=self.input_order, convert_to=self.convert_to, - channel_order=self.channel_order) - pred = reorder_and_crop( - pred, + channel_order=channel_order) + prediction = reorder_and_crop( + prediction, crop_border=self.crop_border, input_order=self.input_order, convert_to=self.convert_to, - channel_order=self.channel_order) - - self._results.append(self.compute_snr(pred, gt)) + channel_order=channel_order) + + if len(prediction.shape) == 3: + prediction = np.expand_dims(prediction, axis=0) + groundtruth = np.expand_dims(groundtruth, axis=0) + _snr_score = [] + for i in range(prediction.shape[0]): + _snr_score.append( + self.compute_snr(prediction[i], groundtruth[i])) + self._results.append(np.array(_snr_score).mean()) def compute_metric(self, results: List[np.float64]) -> Dict[str, float]: """Compute the SignalNoiseRatio metric. diff --git a/mmeval/metrics/ssim.py b/mmeval/metrics/ssim.py index de60feaa..e507e212 100644 --- a/mmeval/metrics/ssim.py +++ b/mmeval/metrics/ssim.py @@ -119,9 +119,14 @@ def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarr convert_to=self.convert_to, channel_order=channel_order) + if len(pred.shape) == 3: + pred = np.expand_dims(pred, axis=0) + gt = np.expand_dims(gt, axis=0) _ssim_score = [] - for i in range(pred.shape[2]): - _ssim_score.append(self.compute_ssim(pred[..., i], gt[..., i])) + for i in range(pred.shape[0]): + for j in range(pred.shape[3]): + _ssim_score.append( + self.compute_ssim(pred[i][..., j], gt[i][..., j])) self._results.append(np.array(_ssim_score).mean()) def compute_metric(self, results: List[np.float64]) -> Dict[str, float]: diff --git a/mmeval/metrics/utils/image_transforms.py b/mmeval/metrics/utils/image_transforms.py index 33a89021..80cd176e 100644 --- a/mmeval/metrics/utils/image_transforms.py +++ b/mmeval/metrics/utils/image_transforms.py @@ -183,6 +183,14 @@ def reorder_and_crop(img: np.ndarray, np.array: The transformation results. """ + if len(img.shape) == 4: + result = [] + for i in range(img.shape[0]): + result.append( + reorder_and_crop(img[i], crop_border, input_order, convert_to, + channel_order)) + return np.array(result).astype(np.float64) + img = reorder_image(img, input_order=input_order) img = img.astype(np.float32) diff --git a/tests/test_metrics/test_mae.py b/tests/test_metrics/test_mae.py index 9947d923..fa11ac69 100644 --- a/tests/test_metrics/test_mae.py +++ b/tests/test_metrics/test_mae.py @@ -5,6 +5,7 @@ def test_mae(): + # test image input preds = [np.ones((32, 32, 3))] gts = [np.ones((32, 32, 3)) * 2] mask = np.ones((32, 32, 3)) * 2 @@ -19,3 +20,19 @@ def test_mae(): mae_results = mae(preds, gts, [mask]) assert isinstance(mae_results, dict) np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627) + + # test video input + preds = [np.ones((5, 32, 32, 3))] + gts = [np.ones((5, 32, 32, 3)) * 2] + mask = np.ones((5, 32, 32, 3)) * 2 + mask[:, :16] *= 0 + + mae = MeanAbsoluteError() + mae_results = mae(preds, gts) + assert isinstance(mae_results, dict) + np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627) + + mae = MeanAbsoluteError() + mae_results = mae(preds, gts, [mask]) + assert isinstance(mae_results, dict) + np.testing.assert_almost_equal(mae_results['mae'], 0.003921568627) diff --git a/tests/test_metrics/test_mse.py b/tests/test_metrics/test_mse.py index 1c08eea1..97d012b0 100644 --- a/tests/test_metrics/test_mse.py +++ b/tests/test_metrics/test_mse.py @@ -5,6 +5,7 @@ def test_mse(): + # test image input preds = [np.ones((32, 32, 3))] gts = [np.ones((32, 32, 3)) * 2] mask = np.ones((32, 32, 3)) * 2 @@ -19,3 +20,19 @@ def test_mse(): mse_results = mse(preds, gts) assert isinstance(mse_results, dict) np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496) + + # test video input + preds = [np.ones((5, 32, 32, 3))] + gts = [np.ones((5, 32, 32, 3)) * 2] + mask = np.ones((5, 32, 32, 3)) * 2 + mask[:, :16] *= 0 + + mse = MeanSquaredError() + mse_results = mse(preds, gts) + assert isinstance(mse_results, dict) + np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496) + + mse = MeanSquaredError() + mse_results = mse(preds, gts) + assert isinstance(mse_results, dict) + np.testing.assert_almost_equal(mse_results['mse'], 0.000015378700496) diff --git a/tests/test_metrics/test_psnr.py b/tests/test_metrics/test_psnr.py index 7efcd1f8..b9f15a7b 100644 --- a/tests/test_metrics/test_psnr.py +++ b/tests/test_metrics/test_psnr.py @@ -53,7 +53,9 @@ def test_psnr_init(): }, [np.ones((32, 32, 3))], [np.ones( (32, 32, 3)) * 2], 49.45272242415597), ({}, [np.ones((32, 32))], [np.ones((32, 32))], float('inf')), - ({}, [np.zeros((32, 32))], [np.ones((32, 32)) * 255], 0)]) + ({}, [np.zeros((32, 32))], [np.ones((32, 32)) * 255], 0), + ({}, [np.ones((5, 3, 32, 32))], [np.ones( + (5, 3, 32, 32)) * 2], 48.1308036086791)]) def test_psnr(metric_kwargs, img1, img2, results): psnr = PeakSignalNoiseRatio(**metric_kwargs) psnr_results = psnr(img1, img2) diff --git a/tests/test_metrics/test_snr.py b/tests/test_metrics/test_snr.py index 0c2575be..15a1c29b 100644 --- a/tests/test_metrics/test_snr.py +++ b/tests/test_metrics/test_snr.py @@ -39,7 +39,9 @@ def test_snr_init(): ({'input_order': 'HWC', 'convert_to': 'Y'}, [np.ones((32, 32, 3))], [np.ones((32, 32, 3)) * 2], 26.290039980499536), ({}, [np.ones((32, 32))], [np.ones((32, 32))], float('inf')), - ({}, [np.zeros((32, 32))], [np.ones((32, 32))], 0) + ({}, [np.zeros((32, 32))], [np.ones((32, 32))], 0), + ({}, [np.ones((5, 3, 32, 32))], [np.ones((5, 3, 32, 32)) * 2], + 6.020599913279624) ] ) def test_snr(metric_kwargs, img1, img2, results): diff --git a/tests/test_metrics/test_ssim.py b/tests/test_metrics/test_ssim.py index 9fde822b..ac2a1606 100644 --- a/tests/test_metrics/test_ssim.py +++ b/tests/test_metrics/test_ssim.py @@ -37,6 +37,8 @@ def test_ssim_init(): [np.ones((3, 32, 32)) * 2], 0.9130623), ({'convert_to': 'Y', 'input_order': 'HWC'}, [np.ones((32, 32, 3))], [np.ones((32, 32, 3)) * 2], 0.9987801), + ({}, [np.ones((5, 3, 32, 32))], [np.ones((5, 3, 32, 32)) * 2], + 0.9130623) ] ) def test_ssim(metric_kwargs, img1, img2, results):