-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
66 lines (48 loc) · 2.14 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from torch import nn as nn
from Utils.spectral_tools import gen_mtf, mtf_kernel_to_torch
class SpectralLoss(nn.Module):
    """Mean L1 loss between low-pass-filtered, downsampled outputs and labels.

    ``forward`` filters ``outputs`` with a fixed (non-trainable) depthwise
    kernel built by ``gen_mtf(ratio, sensor='S2-20', kernel_size=9)``.
    Presumably this is the sensor MTF for a 6-band 'S2-20' set; TODO confirm
    against ``Utils.spectral_tools``. The filtered result is average-pooled by
    ``ratio`` and compared with ``labels`` after ``margin`` pixels are dropped
    from every spatial border of both tensors.

    Args:
        ratio: Scale factor. It is passed to ``gen_mtf`` and also used as the
            ``AvgPool2d`` kernel size (and therefore stride).
        margin: Pixels cropped from each spatial border of the pooled outputs
            and of the labels before the loss. It defaults to 9, the value
            previously hard-coded, presumably to discard border pixels
            affected by the replicate padding. A value of ``0`` disables
            cropping.

    Raises:
        ValueError: If ``margin`` is negative, or if the kernel returned by
            ``mtf_kernel_to_torch`` does not have the depthwise conv weight
            shape.
    """

    def __init__(self, ratio=2, margin=9):
        super().__init__()
        if margin < 0:
            raise ValueError(f"margin must be non-negative, got {margin}")
        self.ratio = ratio
        self.margin = margin
        nbands = 6  # channel count the fixed filter (and inputs) must have
        h_bl = mtf_kernel_to_torch(gen_mtf(ratio, sensor='S2-20', kernel_size=9))
        self.depthconv_20 = nn.Conv2d(
            in_channels=nbands,
            out_channels=nbands,
            groups=nbands,  # depthwise: each band is filtered independently
            padding='same',
            padding_mode='replicate',
            kernel_size=h_bl.shape[-1],
            bias=False,
        )
        expected = tuple(self.depthconv_20.weight.shape)
        if tuple(h_bl.shape) != expected:
            raise ValueError(
                f"MTF kernel has shape {tuple(h_bl.shape)}, expected {expected}"
            )
        # Install the fixed kernel as a frozen parameter rather than writing to
        # ``.weight.data``. Gradients still flow to the conv *input*.
        self.depthconv_20.weight = nn.Parameter(h_bl, requires_grad=False)
        self.avgpool = nn.AvgPool2d(kernel_size=self.ratio)
        self.loss = nn.L1Loss(reduction='mean')

    def _crop(self, x):
        """Return ``x`` without ``self.margin`` pixels on each spatial border."""
        m = self.margin
        height, width = x.shape[-2], x.shape[-1]
        # An empty crop would make the mean L1 loss silently NaN.
        if height <= 2 * m or width <= 2 * m:
            raise ValueError(
                f"spatial size {height}x{width} is too small for a border "
                f"margin of {m} pixels"
            )
        # Explicit end indices so that margin == 0 keeps everything
        # (a ``0:-0`` slice would be empty).
        return x[:, :, m:height - m, m:width - m]

    def forward(self, outputs, labels):
        """Compute the loss.

        Args:
            outputs: Network output with ``nbands`` channels. It is filtered,
                then pooled by ``ratio``.
            labels: Target compared with the pooled outputs. Each tensor is
                cropped relative to its own spatial size.

        Returns:
            Scalar mean absolute error over the cropped region.
        """
        outputs_lp = self.depthconv_20(outputs)
        outputs_lr = self.avgpool(outputs_lp)
        return self.loss(self._crop(outputs_lr), self._crop(labels))
class StructLoss(nn.Module):
    """Mean L1 loss on the high-frequency residual of the network outputs.

    ``forward`` subtracts a fixed (non-trainable) depthwise low-pass filtering
    of ``outputs`` from ``outputs`` itself and compares that residual with
    ``labels``, after ``margin`` pixels are dropped from every spatial border.
    The kernel comes from ``gen_mtf(ratio, sensor='S2-20', kernel_size=9)``.
    Presumably this is the sensor MTF for a 6-band 'S2-20' set; TODO confirm
    against ``Utils.spectral_tools``.

    Args:
        ratio: Scale factor passed to ``gen_mtf`` for the filter design.
        margin: Pixels cropped from each spatial border of the residual and
            of the labels before the loss. It defaults to 9, the value
            previously hard-coded, presumably to discard border pixels
            affected by the replicate padding. A value of ``0`` disables
            cropping.

    Raises:
        ValueError: If ``margin`` is negative, or if the kernel returned by
            ``mtf_kernel_to_torch`` does not have the depthwise conv weight
            shape.
    """

    def __init__(self, ratio=2, margin=9):
        super().__init__()
        if margin < 0:
            raise ValueError(f"margin must be non-negative, got {margin}")
        self.ratio = ratio
        self.margin = margin
        nbands = 6  # channel count the fixed filter (and inputs) must have
        h_bl = mtf_kernel_to_torch(gen_mtf(ratio, sensor='S2-20', kernel_size=9))
        self.depthconv_20 = nn.Conv2d(
            in_channels=nbands,
            out_channels=nbands,
            groups=nbands,  # depthwise: each band is filtered independently
            padding='same',
            padding_mode='replicate',
            kernel_size=h_bl.shape[-1],
            bias=False,
        )
        expected = tuple(self.depthconv_20.weight.shape)
        if tuple(h_bl.shape) != expected:
            raise ValueError(
                f"MTF kernel has shape {tuple(h_bl.shape)}, expected {expected}"
            )
        # Install the fixed kernel as a frozen parameter rather than writing to
        # ``.weight.data``. Gradients still flow to the conv *input*.
        self.depthconv_20.weight = nn.Parameter(h_bl, requires_grad=False)
        self.loss = nn.L1Loss(reduction='mean')

    def _crop(self, x):
        """Return ``x`` without ``self.margin`` pixels on each spatial border."""
        m = self.margin
        height, width = x.shape[-2], x.shape[-1]
        # An empty crop would make the mean L1 loss silently NaN.
        if height <= 2 * m or width <= 2 * m:
            raise ValueError(
                f"spatial size {height}x{width} is too small for a border "
                f"margin of {m} pixels"
            )
        # Explicit end indices so that margin == 0 keeps everything
        # (a ``0:-0`` slice would be empty).
        return x[:, :, m:height - m, m:width - m]

    def forward(self, outputs, labels):
        """Compute the loss.

        Args:
            outputs: Network output with ``nbands`` channels.
            labels: Target compared with ``outputs - lowpass(outputs)``.

        Returns:
            Scalar mean absolute error over the cropped region.
        """
        outputs_hr = outputs - self.depthconv_20(outputs)
        return self.loss(self._crop(outputs_hr), self._crop(labels))