# loss_fns.py — loss functions (forked from fxrshed/ScaledSPS)
from distutils.log import Log
import torch
def logistic_reg(w, X, y):
    """Mean logistic loss for binary labels y in {-1, +1}.

    Computes mean(log(1 + exp(-y * (X @ w)))) via softplus:
    softplus(-m) == log(1 + exp(-m)) but does not overflow to inf
    when the margin m = y * (X @ w) is a large negative number.

    Args:
        w: weight tensor, shape compatible with X @ w.
        y: targets; assumed to take values in {-1, +1} — matches
           LogisticRegression.y_range elsewhere in this file.
    """
    margin = y * (X @ w)
    return torch.mean(torch.nn.functional.softplus(-margin))
def nllsq(w, X, y):
    """Mean squared error between y and sigmoid(X @ w).

    torch.sigmoid replaces the hand-rolled 1 / (1 + exp(-z)), which
    overflows in exp() for large negative z; sigmoid saturates cleanly.
    Targets y are assumed to lie in [0, 1] — matches NLLSQ.y_range
    elsewhere in this file.
    """
    return torch.mean((y - torch.sigmoid(X @ w)) ** 2)
def get_loss(loss):
    """Map a loss name to its implementation.

    Args:
        loss: one of "logreg", "nllsq", "nll_loss".

    Returns:
        A loss-module class ("logreg"/"nllsq") or the torch nll_loss
        function ("nll_loss").

    Raises:
        ValueError: if the name is not recognized.
    """
    if loss == "logreg":
        return LogisticRegression
    elif loss == "nllsq":
        return NLLSQ
    elif loss == "nll_loss":
        # torch.functional.F is not a public/stable attribute path;
        # the documented location is torch.nn.functional.nll_loss.
        return torch.nn.functional.nll_loss
    else:
        raise ValueError("Non-existent loss requested.")
class LogisticRegression(torch.nn.Module):
    """Binary logistic-regression loss module with labels in {-1, +1}."""

    # Expected target values for this loss (presumably used by callers
    # to remap labels — confirm against call sites).
    y_range = torch.tensor([-1., 1.])

    def __init__(self, params):
        # Module.__init__ must run before any attribute assignment:
        # nn.Module.__setattr__ relies on registries it creates, and
        # assigning first raises AttributeError when params is an
        # nn.Parameter.
        super().__init__()
        self.params = params

    def forward(self, input, target):
        """Mean logistic loss of the margin target * (input @ params).

        softplus(-m) == log(1 + exp(-m)) but is overflow-safe for
        large negative margins.
        """
        margin = target * (input @ self.params)
        return torch.mean(torch.nn.functional.softplus(-margin))
class NLLSQ(torch.nn.Module):
    """Nonlinear least-squares loss: MSE between target and sigmoid(input @ params)."""

    # Expected target values for this loss (presumably used by callers
    # to remap labels — confirm against call sites).
    y_range = torch.tensor([0., 1.])

    def __init__(self, params):
        # Module.__init__ must run before any attribute assignment:
        # nn.Module.__setattr__ relies on registries it creates, and
        # assigning first raises AttributeError when params is an
        # nn.Parameter.
        super().__init__()
        self.params = params

    def forward(self, input, target):
        """Mean squared error between target and sigmoid(input @ params).

        torch.sigmoid replaces 1 / (1 + exp(-z)), which overflows in
        exp() for large negative z.
        """
        return torch.mean((target - torch.sigmoid(input @ self.params)) ** 2)