Skip to content

Commit f656e5a

Browse files
authored
Fix error in ergas calculation (#2498)
* fix error in formula * fix doctests * changelog * fix other doctests
1 parent 6e088fe commit f656e5a

File tree

6 files changed

+11
-9
lines changed

6 files changed

+11
-9
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))
3737

3838

39+
- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))
40+
41+
3942
- Fixed `BootStrapper` wrapper not working with `kwargs` provided argument ([#2503](https://github.com/Lightning-AI/torchmetrics/pull/2503))
4043

4144

src/torchmetrics/functional/image/_deprecated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _error_relative_global_dimensionless_synthesis(
5353
>>> target = preds * 0.75
5454
>>> ergds = _error_relative_global_dimensionless_synthesis(preds, target)
5555
>>> torch.round(ergds)
56-
tensor(154.)
56+
tensor(10.)
5757
5858
"""
5959
_deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image")

src/torchmetrics/functional/image/ergas.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _ergas_compute(
6767
>>> target = preds * 0.75
6868
>>> preds, target = _ergas_update(preds, target)
6969
>>> torch.round(_ergas_compute(preds, target))
70-
tensor(154.)
70+
tensor(10.)
7171
7272
"""
7373
b, c, h, w = preds.shape
@@ -79,7 +79,7 @@ def _ergas_compute(
7979
rmse_per_band = torch.sqrt(sum_squared_error / (h * w))
8080
mean_target = torch.mean(target, dim=2)
8181

82-
ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
82+
ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
8383
return reduce(ergas_score, reduction)
8484

8585

@@ -115,9 +115,8 @@ def error_relative_global_dimensionless_synthesis(
115115
>>> gen = torch.manual_seed(42)
116116
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
117117
>>> target = preds * 0.75
118-
>>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
119-
>>> torch.round(ergds)
120-
tensor(154.)
118+
>>> error_relative_global_dimensionless_synthesis(preds, target)
119+
tensor(9.6193)
121120
122121
"""
123122
preds, target = _ergas_update(preds, target)

src/torchmetrics/image/_deprecated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionles
2222
>>> target = preds * 0.75
2323
>>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
2424
>>> torch.round(ergas(preds, target))
25-
tensor(154.)
25+
tensor(10.)
2626
2727
"""
2828

src/torchmetrics/image/ergas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
6969
>>> target = preds * 0.75
7070
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
7171
>>> torch.round(ergas(preds, target))
72-
tensor(154.)
72+
tensor(10.)
7373
7474
"""
7575

tests/unittests/image/test_ergas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _reference_ergas(
6565
rmse_per_band = torch.sqrt(sum_squared_error / (h * w))
6666
mean_target = torch.mean(sk_target, dim=2)
6767
# compute ergas score
68-
ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
68+
ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
6969
# reduction
7070
if reduction == "sum":
7171
return torch.sum(ergas_score)

0 commit comments

Comments
 (0)