Skip to content

Commit 0095ee2

Browse files
authored
Update ssim.py
1 parent 409f601 commit 0095ee2

File tree

1 file changed

+63
-42
lines changed

1 file changed

+63
-42
lines changed

ssim.py

Lines changed: 63 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
2+
13
import torch
24
import torch.nn as nn
35
import torch.nn.functional as F
@@ -84,59 +86,78 @@ def __init__(
8486
self.keep_batch_dim = keep_batch_dim
8587
self.return_log = return_log
8688

87-
self.gaussian_filer = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma)
89+
self.gaussian_filter = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma)
8890

91+
@torch.cuda.amp.autocast(enabled=False)
8992
def forward(self, x, y):
90-
return ssim(
91-
x,
92-
y,
93-
gaussian_filter=self.gaussian_filer,
94-
C1=self.C1,
95-
C2=self.C2,
96-
keep_batch_dim=self.keep_batch_dim,
97-
return_log=self.return_log,
98-
)
93+
"""Calculate the mean SSIM (MSSIM) between two 4d tensors.
9994
95+
Args:
96+
x (Tensor): 4d tensor
97+
y (Tensor): 4d tensor
10098
101-
@torch.cuda.amp.autocast(enabled=False)
102-
def ssim(x, y, gaussian_filter, C1, C2, keep_batch_dim=False, return_log=False):
103-
"""Calculate the mean SSIM (MSSIM) between two 4d tensors.
99+
Returns:
100+
Tensor: MSSIM
101+
"""
102+
assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same"
103+
assert x.ndim == y.ndim == 4, f"x: {x.ndim} and y: {y.ndim} must be 4"
104+
assert (
105+
x.type() == y.type() == self.gaussian_filter.gaussian_window2d.type()
106+
), f"x: {x.type()} and y: {y.type()} must be {self.gaussian_filter.gaussian_window2d.type()}"
107+
108+
mu_x = self.gaussian_filter(x) # equ 14
109+
mu_y = self.gaussian_filter(y) # equ 14
110+
sigma2_x = self.gaussian_filter(x * x) - mu_x * mu_x # equ 15
111+
sigma2_y = self.gaussian_filter(y * y) - mu_y * mu_y # equ 15
112+
sigma_xy = self.gaussian_filter(x * y) - mu_x * mu_y # equ 16
113+
114+
# equ 13 in ref1
115+
A1 = 2 * mu_x * mu_y + self.C1
116+
A2 = 2 * sigma_xy + self.C2
117+
B1 = mu_x * mu_x + mu_y * mu_y + self.C1
118+
B2 = sigma2_x + sigma2_y + self.C2
119+
S = (A1 * A2) / (B1 * B2)
120+
121+
if self.return_log:
122+
S = S - S.min()
123+
S = S / S.max()
124+
S = -torch.log(S + 1e-8)
125+
126+
if self.keep_batch_dim:
127+
return S.mean(dim=(1, 2, 3))
128+
else:
129+
return S.mean()
130+
131+
132+
def ssim(
133+
x, y, *, window_size=11, in_channels=1, sigma=1.5, K1=0.01, K2=0.03, L=1, keep_batch_dim=False, return_log=False
134+
):
135+
"""Calculate the mean SSIM (MSSIM) between two 4D tensors.
104136
105137
Args:
106138
x (Tensor): 4d tensor
107139
y (Tensor): 4d tensor
108-
gaussian_filter (GaussianFilter2D): the gaussian filter object
109-
C1 (float): the constant to avoid instability
110-
C2 (float): the constant to avoid instability
140+
window_size (int, optional): The window size of the gaussian filter. Defaults to 11.
141+
in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
142+
sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
143+
K1 (float, optional): K1 of MSSIM. Defaults to 0.01.
144+
K2 (float, optional): K2 of MSSIM. Defaults to 0.03.
145+
L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
111146
keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
112147
return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
113148
149+
114150
Returns:
115151
Tensor: MSSIM
116152
"""
117-
assert x.shape == y.shape, f"x: {x.shape} != y: {y.shape}"
118-
assert x.ndim == y.ndim == 4, f"x: {x.ndim} != y: {y.ndim} != 4"
119-
assert x.type() == y.type(), f"x: {x.type()} != y: {y.type()}"
120-
121-
mu_x = gaussian_filter(x) # equ 14
122-
mu_y = gaussian_filter(y) # equ 14
123-
sigma2_x = gaussian_filter(x * x) - mu_x * mu_x # equ 15
124-
sigma2_y = gaussian_filter(y * y) - mu_y * mu_y # equ 15
125-
sigma_xy = gaussian_filter(x * y) - mu_x * mu_y # equ 16
126-
127-
# equ 13 in ref1
128-
A1 = 2 * mu_x * mu_y + C1
129-
A2 = 2 * sigma_xy + C2
130-
B1 = mu_x * mu_x + mu_y * mu_y + C1
131-
B2 = sigma2_x + sigma2_y + C2
132-
S = (A1 * A2) / (B1 * B2)
133-
134-
if return_log:
135-
S = S - S.min()
136-
S = S / S.max()
137-
S = -torch.log(S + 1e-8)
138-
139-
if keep_batch_dim:
140-
return S.mean(dim=(1, 2, 3))
141-
else:
142-
return S.mean()
153+
ssim_obj = SSIM(
154+
window_size=window_size,
155+
in_channels=in_channels,
156+
sigma=sigma,
157+
K1=K1,
158+
K2=K2,
159+
L=L,
160+
keep_batch_dim=keep_batch_dim,
161+
return_log=return_log,
162+
)
163+
return ssim_obj(x, y)

0 commit comments

Comments
 (0)