@@ -67,7 +67,7 @@ def _ergas_compute(
67
67
>>> target = preds * 0.75
68
68
>>> preds, target = _ergas_update(preds, target)
69
69
>>> torch.round(_ergas_compute(preds, target))
70
- tensor(154 .)
70
+ tensor(10 .)
71
71
72
72
"""
73
73
b , c , h , w = preds .shape
@@ -79,7 +79,7 @@ def _ergas_compute(
79
79
rmse_per_band = torch .sqrt (sum_squared_error / (h * w ))
80
80
mean_target = torch .mean (target , dim = 2 )
81
81
82
- ergas_score = 100 * ratio * torch .sqrt (torch .sum ((rmse_per_band / mean_target ) ** 2 , dim = 1 ) / c )
82
+ ergas_score = 100 / ratio * torch .sqrt (torch .sum ((rmse_per_band / mean_target ) ** 2 , dim = 1 ) / c )
83
83
return reduce (ergas_score , reduction )
84
84
85
85
@@ -115,9 +115,8 @@ def error_relative_global_dimensionless_synthesis(
115
115
>>> gen = torch.manual_seed(42)
116
116
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
117
117
>>> target = preds * 0.75
118
- >>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
119
- >>> torch.round(ergds)
120
- tensor(154.)
118
+ >>> error_relative_global_dimensionless_synthesis(preds, target)
119
+ tensor(9.6193)
121
120
122
121
"""
123
122
preds , target = _ergas_update (preds , target )
0 commit comments