-
Notifications
You must be signed in to change notification settings - Fork 0
/
_loss.py
51 lines (35 loc) · 1.37 KB
/
_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
class CustomLoss(object):
    """Direction-aware weighted mean-squared-error losses.

    Each loss is the element-wise squared error ``(output - target) ** 2``
    multiplied by ``w1`` where the direction of ``output`` relative to a
    baseline agrees with the direction of ``target`` relative to the same
    baseline, and by ``w2`` where the directions disagree. The mean of the
    weighted errors is returned. The two public methods differ only in the
    baseline they use.

    Args:
        w1: Weight for elements whose direction matches (user-configurable).
        w2: Weight for elements whose direction does not match
            (user-configurable).
    """

    def __init__(self, w1, w2):
        self.w1 = w1
        self.w2 = w2
        # Retained for backward compatibility only: weights are now selected
        # directly from the boolean mismatch mask, so no tolerance is needed.
        self.eps = 1e-7

    def _weighted_mse(self, output, target, mismatch):
        """Return ``mean((output - target) ** 2 * weights)``.

        ``weights`` is ``w2`` where ``mismatch`` is True and ``w1`` elsewhere.
        The weights are allocated on the same device as the squared error, so
        the loss works on CPU and on any CUDA device instead of only "cuda".
        """
        se = (output - target) ** 2
        ones = torch.ones(mismatch.shape, device=se.device)
        weights = torch.where(mismatch, ones * self.w2, ones * self.w1)
        return torch.mean(se * weights)

    def custom_loss_1(self, output, target):
        """Weighted MSE with the first target value as the direction baseline.

        The baseline is ``target[0]`` (first entry along dim 0), broadcast
        against every entry. An element is a direction mismatch when
        ``sign(output - baseline) != sign(target - baseline)``.

        Args:
            output: Predictions; same shape as ``target``.
            target: Ground-truth values; ``target[0]`` is the baseline.

        Returns:
            Scalar tensor with the weighted mean squared error.
        """
        # Use target[0] itself as the baseline. It must not be multiplied by
        # the column count, which would shift the baseline for multi-column data.
        baseline = target[0]
        mismatch = torch.sign(output - baseline) != torch.sign(target - baseline)
        return self._weighted_mse(output, target, mismatch)

    def custom_loss_2(self, output, target):
        """Weighted MSE with the previous target value as the direction baseline.

        For each index ``t >= 1`` along dim 0, an element is a direction
        mismatch when ``sign(output[t] - target[t - 1]) !=
        sign(target[t] - target[t - 1])``. Index 0 has no previous value and
        is always weighted with ``w1``.

        Args:
            output: Predictions; same shape as ``target``.
            target: Ground-truth values ordered along dim 0.

        Returns:
            Scalar tensor with the weighted mean squared error.
        """
        mismatch = torch.sign(output[1:] - target[:-1]) != torch.sign(
            target[1:] - target[:-1]
        )
        # Index 0 has no predecessor: pad with a "direction matches" row,
        # built as bool on the mask's device so torch.cat needs no promotion.
        first = torch.zeros(
            (1,) + tuple(mismatch.shape[1:]),
            dtype=torch.bool,
            device=mismatch.device,
        )
        mismatch = torch.cat((first, mismatch))
        return self._weighted_mse(output, target, mismatch)