Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for custom prefix/postfix and metric collection #2070

Merged
merged 7 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))


- Fixed bug related to `MetricCollection` used with custom metrics have `prefix`/`postfix` attributes ([#2070](https://github.com/Lightning-AI/torchmetrics/pull/2070))

## [1.1.1] - 2023-08-29

### Added
Expand Down
9 changes: 6 additions & 3 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,18 @@ def _compute_and_reduce(
_, duplicates = _flatten_dict(result)

flattened_results = {}
for k, res in result.items():
for k, m in self.items(keep_base=True, copy_state=False):
res = result[k]
if isinstance(res, dict):
for key, v in res.items():
# if duplicates of keys we need to add unique prefix to each key
if duplicates:
stripped_k = k.replace(getattr(m, "prefix", ""), "")
stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
key = f"{stripped_k}_{key}"
if hasattr(m, "prefix") and m.prefix is not None:
if getattr(m, "_from_collection", None) and m.prefix is not None:
key = f"{m.prefix}{key}"
if hasattr(m, "postfix") and m.postfix is not None:
if getattr(m, "_from_collection", None) and m.postfix is not None:
key = f"{key}{m.postfix}"
flattened_results[key] = v
else:
Expand Down Expand Up @@ -425,6 +426,7 @@ def add_metrics(
for k, v in metric.items(keep_base=False):
v.postfix = metric.postfix
v.prefix = metric.prefix
v._from_collection = True
self[f"{name}_{k}"] = v
elif isinstance(metrics, Sequence):
for metric in metrics:
Expand All @@ -442,6 +444,7 @@ def add_metrics(
for k, v in metric.items(keep_base=False):
v.postfix = metric.postfix
v.prefix = metric.prefix
v._from_collection = True
self[k] = v
else:
raise ValueError(
Expand Down
33 changes: 33 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,36 @@ def test_double_nested_collections(base_metrics, expected):

for key in val:
assert key in expected


def test_with_custom_prefix_postfix():
"""Test that metric colection does not clash with custom prefix and postfix in users metrics.

See issue: https://github.com/Lightning-AI/torchmetrics/issues/2065

"""

class CustomAccuracy(MulticlassAccuracy):
prefix = "my_prefix"
postfix = "my_postfix"

def compute(self):
value = super().compute()
return {f"{self.prefix}/accuracy/{self.postfix}": value}

class CustomPrecision(MulticlassAccuracy):
prefix = "my_prefix"
postfix = "my_postfix"

def compute(self):
value = super().compute()
return {f"{self.prefix}/precision/{self.postfix}": value}

metrics = MetricCollection([CustomAccuracy(num_classes=2), CustomPrecision(num_classes=2)])

# Update metrics with current batch
res = metrics(torch.tensor([1, 0, 0, 1]), torch.tensor([1, 0, 0, 0]))

# Print the calculated metrics
assert "my_prefix/accuracy/my_postfix" in res
assert "my_prefix/precision/my_postfix" in res
Loading