-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmetrics.py
80 lines (67 loc) · 2.31 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Weighted R² functions."""
import numpy as np
import torch
import torch.nn as nn
def r2_weighted(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    sample_weight: np.ndarray,
    *,
    epsilon: float = 1e-38,
) -> float:
    """Compute the weighted (uncentered) R² score with NumPy.

    The score is ``1 - mean_w((y_pred - y_true) ** 2) / mean_w(y_true ** 2)``,
    where ``mean_w`` is the ``sample_weight``-weighted average computed by
    ``np.average``. The denominator is the weighted mean of ``y_true ** 2``,
    not the weighted variance of ``y_true``. Predictions are therefore scored
    against a constant-zero baseline rather than against the mean.

    Args:
        y_true: Ground truth values.
        y_pred: Predicted values.
        sample_weight: Weights for each observation.
        epsilon: Small constant added to the denominator so that an all-zero
            ``y_true`` does not cause a division by zero. Defaults to 1e-38.

    Returns:
        float: Weighted R² score.

    Raises:
        ZeroDivisionError: If ``sample_weight`` sums to zero (raised by
            ``np.average``, which cannot normalize the weights).
    """
    weighted_sq_error = np.average((y_pred - y_true) ** 2, weights=sample_weight)
    weighted_sq_target = np.average(y_true ** 2, weights=sample_weight)
    return 1 - weighted_sq_error / (weighted_sq_target + epsilon)
def r2_weighted_torch(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    sample_weight: torch.Tensor,
    *,
    epsilon: float = 1e-38,
) -> torch.Tensor:
    """Compute the weighted (uncentered) R² score using PyTorch tensors.

    The score is ``1 - sum(w * (y_pred - y_true) ** 2) / sum(w * y_true ** 2)``.
    Weighted sums are used instead of weighted means. The ratio is the same
    as in the NumPy ``np.average`` formulation, except for how ``epsilon``
    scales. The baseline is a constant-zero prediction, not the mean of
    ``y_true``.

    Args:
        y_true (torch.Tensor): Ground truth tensor.
        y_pred (torch.Tensor): Predicted tensor.
        sample_weight (torch.Tensor): Weights for each observation (same shape
            as y_true).
        epsilon (float, optional): Small constant added to the denominator so
            that an all-zero ``y_true`` does not cause a division by zero.
            Defaults to 1e-38. NOTE: 1e-38 underflows to zero in float16, so
            pass a larger value for half-precision inputs.

    Returns:
        torch.Tensor: Scalar tensor holding the weighted R² score.
    """
    sq_error = torch.sum(sample_weight * (y_pred - y_true) ** 2)
    sq_target = torch.sum(sample_weight * y_true ** 2) + epsilon
    return 1 - sq_error / sq_target
class WeightedR2Loss(nn.Module):
    """PyTorch loss equal to ``1 - weighted R²`` (uncentered).

    ``forward`` returns ``sum(w * (y_pred - y_true) ** 2) /
    (sum(w * y_true ** 2) + epsilon)``. Minimizing it maximizes the weighted
    R² score computed against a constant-zero baseline.
    """

    def __init__(self, epsilon: float = 1e-38) -> None:
        """
        Initialize the WeightedR2Loss class.

        Args:
            epsilon (float, optional): Small constant added to the denominator
                for numerical stability. Defaults to 1e-38. NOTE: 1e-38
                underflows to zero in float16, so pass a larger value for
                half-precision inputs.
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(
        self,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        weights: torch.Tensor
    ) -> torch.Tensor:
        """Compute the weighted R² loss.

        Args:
            y_pred (torch.Tensor): Predicted tensor.
            y_true (torch.Tensor): Ground truth tensor.
            weights (torch.Tensor): Weights for each observation (same shape
                as y_true).

        Returns:
            torch.Tensor: Scalar tensor holding the weighted R² loss.
        """
        sq_error = torch.sum(weights * (y_pred - y_true) ** 2)
        # Use the configured epsilon; previously a hard-coded 1e-38 was used
        # here, silently ignoring the constructor argument.
        sq_target = torch.sum(weights * y_true ** 2) + self.epsilon
        return sq_error / sq_target