|
| 1 | +# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py |
| 2 | + |
1 | 3 | import torch
|
2 | 4 | import torch.nn as nn
|
3 | 5 | import torch.nn.functional as F
|
@@ -84,59 +86,78 @@ def __init__(
|
84 | 86 | self.keep_batch_dim = keep_batch_dim
|
85 | 87 | self.return_log = return_log
|
86 | 88 |
|
87 |
| - self.gaussian_filer = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma) |
| 89 | + self.gaussian_filter = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma) |
88 | 90 |
|
| 91 | + @torch.cuda.amp.autocast(enabled=False) |
89 | 92 | def forward(self, x, y):
|
90 |
| - return ssim( |
91 |
| - x, |
92 |
| - y, |
93 |
| - gaussian_filter=self.gaussian_filer, |
94 |
| - C1=self.C1, |
95 |
| - C2=self.C2, |
96 |
| - keep_batch_dim=self.keep_batch_dim, |
97 |
| - return_log=self.return_log, |
98 |
| - ) |
| 93 | + """Calculate the mean SSIM (MSSIM) between two 4d tensors. |
99 | 94 |
|
| 95 | + Args: |
| 96 | + x (Tensor): 4d tensor |
| 97 | + y (Tensor): 4d tensor |
100 | 98 |
|
101 |
| -@torch.cuda.amp.autocast(enabled=False) |
102 |
| -def ssim(x, y, gaussian_filter, C1, C2, keep_batch_dim=False, return_log=False): |
103 |
| - """Calculate the mean SSIM (MSSIM) between two 4d tensors. |
| 99 | + Returns: |
| 100 | + Tensor: MSSIM |
| 101 | + """ |
| 102 | + assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same" |
| 103 | + assert x.ndim == y.ndim == 4, f"x: {x.ndim} and y: {y.ndim} must be 4" |
| 104 | + assert ( |
| 105 | + x.type() == y.type() == self.gaussian_filter.gaussian_window2d.type() |
| 106 | + ), f"x: {x.type()} and y: {y.type()} must be {self.gaussian_filter.gaussian_window2d.type()}" |
| 107 | + |
| 108 | + mu_x = self.gaussian_filter(x) # equ 14 |
| 109 | + mu_y = self.gaussian_filter(y) # equ 14 |
| 110 | + sigma2_x = self.gaussian_filter(x * x) - mu_x * mu_x # equ 15 |
| 111 | + sigma2_y = self.gaussian_filter(y * y) - mu_y * mu_y # equ 15 |
| 112 | + sigma_xy = self.gaussian_filter(x * y) - mu_x * mu_y # equ 16 |
| 113 | + |
| 114 | + # equ 13 in ref1 |
| 115 | + A1 = 2 * mu_x * mu_y + self.C1 |
| 116 | + A2 = 2 * sigma_xy + self.C2 |
| 117 | + B1 = mu_x * mu_x + mu_y * mu_y + self.C1 |
| 118 | + B2 = sigma2_x + sigma2_y + self.C2 |
| 119 | + S = (A1 * A2) / (B1 * B2) |
| 120 | + |
| 121 | + if self.return_log: |
| 122 | + S = S - S.min() |
| 123 | + S = S / S.max() |
| 124 | + S = -torch.log(S + 1e-8) |
| 125 | + |
| 126 | + if self.keep_batch_dim: |
| 127 | + return S.mean(dim=(1, 2, 3)) |
| 128 | + else: |
| 129 | + return S.mean() |
| 130 | + |
| 131 | + |
| 132 | +def ssim( |
| 133 | + x, y, *, window_size=11, in_channels=1, sigma=1.5, K1=0.01, K2=0.03, L=1, keep_batch_dim=False, return_log=False |
| 134 | +): |
| 135 | + """Calculate the mean SSIM (MSSIM) between two 4D tensors. |
104 | 136 |
|
105 | 137 | Args:
|
106 | 138 | x (Tensor): 4d tensor
|
107 | 139 | y (Tensor): 4d tensor
|
108 |
| - gaussian_filter (GaussianFilter2D): the gaussian filter object |
109 |
| - C1 (float): the constant to avoid instability |
110 |
| - C2 (float): the constant to avoid instability |
| 140 | + window_size (int, optional): The window size of the gaussian filter. Defaults to 11. |
| 141 | + in_channels (int, optional): The number of channels of the 4d tensor. Defaults to False. |
| 142 | + sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5. |
| 143 | + K1 (float, optional): K1 of MSSIM. Defaults to 0.01. |
| 144 | + K2 (float, optional): K2 of MSSIM. Defaults to 0.03. |
| 145 | + L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1. |
111 | 146 | keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
|
112 | 147 | return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
|
113 | 148 |
|
| 149 | +
|
114 | 150 | Returns:
|
115 | 151 | Tensor: MSSIM
|
116 | 152 | """
|
117 |
| - assert x.shape == y.shape, f"x: {x.shape} != y: {y.shape}" |
118 |
| - assert x.ndim == y.ndim == 4, f"x: {x.ndim} != y: {y.ndim} != 4" |
119 |
| - assert x.type() == y.type(), f"x: {x.type()} != y: {y.type()}" |
120 |
| - |
121 |
| - mu_x = gaussian_filter(x) # equ 14 |
122 |
| - mu_y = gaussian_filter(y) # equ 14 |
123 |
| - sigma2_x = gaussian_filter(x * x) - mu_x * mu_x # equ 15 |
124 |
| - sigma2_y = gaussian_filter(y * y) - mu_y * mu_y # equ 15 |
125 |
| - sigma_xy = gaussian_filter(x * y) - mu_x * mu_y # equ 16 |
126 |
| - |
127 |
| - # equ 13 in ref1 |
128 |
| - A1 = 2 * mu_x * mu_y + C1 |
129 |
| - A2 = 2 * sigma_xy + C2 |
130 |
| - B1 = mu_x * mu_x + mu_y * mu_y + C1 |
131 |
| - B2 = sigma2_x + sigma2_y + C2 |
132 |
| - S = (A1 * A2) / (B1 * B2) |
133 |
| - |
134 |
| - if return_log: |
135 |
| - S = S - S.min() |
136 |
| - S = S / S.max() |
137 |
| - S = -torch.log(S + 1e-8) |
138 |
| - |
139 |
| - if keep_batch_dim: |
140 |
| - return S.mean(dim=(1, 2, 3)) |
141 |
| - else: |
142 |
| - return S.mean() |
| 153 | + ssim_obj = SSIM( |
| 154 | + window_size=window_size, |
| 155 | + in_channels=in_channels, |
| 156 | + sigma=sigma, |
| 157 | + K1=K1, |
| 158 | + K2=K2, |
| 159 | + L=L, |
| 160 | + keep_batch_dim=keep_batch_dim, |
| 161 | + return_log=return_log, |
| 162 | + ) |
| 163 | + return ssim_obj(x, y) |
0 commit comments