-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
42 lines (33 loc) · 1.38 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import nn
"""
This file contains the module classes that implement loss functions
in this small framework. Each loss function extends the base module and hence has
a forward() and backward() method. The following loss functions are implemented
currently in the framework:
- Mean Squared Error (MSE)
Authors: Albergoni, Bouvet, Feo
"""
"""
This module implements the Mean Squared Error (MSE) loss function. Its forward method
can compute the loss on batches of predictions compared to true targets, aggregating the result
either by summing or averaging. The backward method instead computes the gradient of the
loss function with respect to the predictions.
"""
class LossMSE(nn.Module):
    """Mean Squared Error (MSE) loss.

    ``forward`` returns the squared error between predictions and targets,
    reduced with either ``'sum'`` or ``'mean'``. ``backward`` returns the
    gradient of that reduced loss with respect to the predictions.
    """

    def __init__(self):
        super(LossMSE, self).__init__()
        # Reduction names accepted by forward() and backward().
        self._supported_reductions = ['sum', 'mean']

    def _check_reduction(self, reduction):
        """Raise ValueError if `reduction` is not a supported reduction name."""
        if reduction not in self._supported_reductions:
            raise ValueError("This loss reduction is not supported : %s" % reduction)

    def forward(self, x, t, reduction='sum'):
        """Compute the MSE loss between predictions `x` and targets `t`.

        Args:
            x: predictions.
            t: targets, checked against `x` by `self._check.shapes_match`.
            reduction: 'sum' adds up every squared error; 'mean' sums the
                squared errors along axis 1 and averages those sums, so it
                needs inputs with at least two dimensions.

        Returns:
            The reduced loss.

        Raises:
            ValueError: if `reduction` is not supported.
        """
        # NOTE(review): `_check` is not defined in this class; presumably it is
        # provided by nn.Module -- confirm against the base class.
        self._check.shapes_match(x, t)
        self._check_reduction(reduction)

        squared_error = (x - t).pow(2)
        if reduction == 'sum':
            return squared_error.sum()
        elif reduction == 'mean':
            return squared_error.sum(1).mean()

    def backward(self, x, t, reduction='sum'):
        """Compute the gradient of the loss with respect to the predictions `x`.

        Args:
            x: predictions.
            t: targets, checked against `x` by `self._check.shapes_match`.
            reduction: the reduction that was used in forward(). Defaults to
                'sum', which yields the unscaled gradient ``2 * (x - t)``.

        Returns:
            A tensor shaped like `x` holding d(loss)/d(x).

        Raises:
            ValueError: if `reduction` is not supported.
        """
        # NOTE(review): see forward() -- `_check` presumably comes from nn.Module.
        self._check.shapes_match(x, t)
        self._check_reduction(reduction)

        grad = 2 * (x - t)
        if reduction == 'mean':
            # forward() averages the axis-1 sums, so every element's gradient
            # is divided by the number of those sums (the element count of
            # the shape with axis 1 removed; the batch size for 2-D input).
            n_sums = torch.Size(x.shape[:1] + x.shape[2:]).numel()
            grad = grad / n_sums
        return grad