Fix simplify ssim implementation #2563
Conversation
…/sadra-barikbin/ignite into Fix-simplify-SSIM-implementation
Thanks @sadra-barikbin, lgtm!
@sadra-barikbin I had to run it on cuda to make sure :(
@@ -98,8 +98,7 @@ def __init__(
 
     @reinit__is_reduced
     def reset(self) -> None:
-        # Not a tensor because batch size is not known in advance.
-        self._sum_of_batchwise_ssim = 0.0  # type: Union[float, torch.Tensor]
+        self._sum_of_ssim = torch.tensor(0.0, device=self._device)
This should be float64
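A minimal sketch of what that suggestion could look like, assuming the accumulator stays on `self._device` as in the diff above (the `SSIMSketch` class name and its trimmed-down `__init__` are placeholders, not ignite's actual code):

```python
import torch

from ignite.metrics.metric import reinit__is_reduced


class SSIMSketch:
    # Placeholder class: only the reset logic relevant to the review comment is shown.
    def __init__(self, device: str = "cpu") -> None:
        self._device = torch.device(device)
        self.reset()

    @reinit__is_reduced
    def reset(self) -> None:
        # Accumulate in float64 so summing many per-image SSIM values does not
        # lose precision, while keeping the accumulator on the metric's device.
        self._sum_of_ssim = torch.tensor(0.0, dtype=torch.float64, device=self._device)
        self._num_examples = 0
```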
         self._num_examples += y.shape[0]
 
-    @sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
+    @sync_all_reduce("_sum_of_ssim", "_num_examples")
     def compute(self) -> torch.Tensor:
Let's return just a Python float instead of a tensor?
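One possible shape for that, sketched as a minimal Metric subclass rather than the real SSIM implementation (`MeanSSIMSketch` and its `update()` signature, which takes precomputed per-image SSIM values, are illustrative assumptions):

```python
import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class MeanSSIMSketch(Metric):
    # Skeleton only: update() takes precomputed per-image SSIM scores instead of
    # doing the real kernel-based computation, so just the accumulate / reduce /
    # return-a-float structure is illustrated.

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_ssim = torch.tensor(0.0, dtype=torch.float64, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, per_image_ssim: torch.Tensor) -> None:
        # per_image_ssim stands in for the per-image SSIM values of one batch.
        self._sum_of_ssim += per_image_ssim.sum().to(self._device)
        self._num_examples += per_image_ssim.shape[0]

    @sync_all_reduce("_sum_of_ssim", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("SSIM must have at least one example before it can be computed.")
        # .item() turns the accumulated 0-dim tensor into a plain Python float.
        return (self._sum_of_ssim / self._num_examples).item()


# Quick check of the sketch:
m = MeanSSIMSketch()
m.update(torch.tensor([0.92, 0.88, 0.90]))
print(type(m.compute()))  # <class 'float'>
```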
    assert isinstance(ignite_ssim, torch.Tensor)
    assert ignite_ssim.dtype == torch.float64
    assert ignite_ssim.device == torch.device(device)
    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=atol)


def test_ssim():
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    y_pred = torch.rand(8, 3, 224, 224, device=device)
    y = y_pred * 0.8
    _test_ssim(
        y_pred, y, data_range=1.0, kernel_size=7, sigma=1.5, gaussian=False, use_sample_covariance=True, device=device
    )

    assert ignite_ssim.device.type == torch.device(device).type
    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=7e-5)
As we can return a float instead of a Tensor, let's update the tests.
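If compute() does end up returning a plain float, the tensor-specific assertions could shrink to something like this (check_ssim_output is a hypothetical helper, not part of the existing test file; skimg_ssim stands for the scikit-image reference value the tests already compute):

```python
import numpy as np


def check_ssim_output(ignite_ssim: float, skimg_ssim: float, atol: float = 7e-5) -> None:
    # With a plain Python float there is no tensor dtype or device left to assert on.
    assert isinstance(ignite_ssim, float)
    assert np.allclose(ignite_ssim, skimg_ssim, atol=atol)


check_ssim_output(0.912345, 0.912349)  # values agree within atol, so this passes
```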
Fixes #2532
Description:
Based on #2562, the `update` method became simpler.

Check list: