
Commit c94e21a

SkafteNicki and Borda authored
Fix multitask wrapper not being logged in lightning when used together with collections (Lightning-AI#2349)
* integration tests
* implementation
* better testing
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent 9253717 commit c94e21a
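Background for the fix: previously `MultitaskWrapper.items()`, `keys()`, and `values()` delegated directly to the underlying `nn.ModuleDict`, so a task backed by a `MetricCollection` was yielded as a single opaque module, and logging the wrapper through Lightning's `self.log_dict` (which iterates those items) did not log the sub-metrics. The diff below adds a `flatten` argument, defaulting to `True`, that makes iteration yield one `<task>_<metric>` entry per sub-metric so each can be logged individually.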

File tree: 4 files changed (+115 / -15 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))
+
+
 - Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346))
 
 

src/torchmetrics/wrappers/multitask.py

Lines changed: 41 additions & 9 deletions
@@ -103,17 +103,49 @@ def __init__(
         super().__init__()
         self.task_metrics = nn.ModuleDict(task_metrics)
 
-    def items(self) -> Iterable[Tuple[str, nn.Module]]:
-        """Iterate over task and task metrics."""
-        return self.task_metrics.items()
+    def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]:
+        """Iterate over task and task metrics.
 
-    def keys(self) -> Iterable[str]:
-        """Iterate over task names."""
-        return self.task_metrics.keys()
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for task_name, metric in self.task_metrics.items():
+            if flatten and isinstance(metric, MetricCollection):
+                for sub_metric_name, sub_metric in metric.items():
+                    yield f"{task_name}_{sub_metric_name}", sub_metric
+            else:
+                yield task_name, metric
+
+    def keys(self, flatten: bool = True) -> Iterable[str]:
+        """Iterate over task names.
+
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for task_name, metric in self.task_metrics.items():
+            if flatten and isinstance(metric, MetricCollection):
+                for sub_metric_name in metric:
+                    yield f"{task_name}_{sub_metric_name}"
+            else:
+                yield task_name
 
-    def values(self) -> Iterable[nn.Module]:
-        """Iterate over task metrics."""
-        return self.task_metrics.values()
+    def values(self, flatten: bool = True) -> Iterable[nn.Module]:
+        """Iterate over task metrics.
+
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for metric in self.task_metrics.values():
+            if flatten and isinstance(metric, MetricCollection):
+                yield from metric.values()
+            else:
+                yield metric
 
     @staticmethod
     def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None:
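For illustration, a small runnable sketch of the new iteration behaviour (not part of the commit; the metric choices mirror the unit test added below):

from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.wrappers import MultitaskWrapper

multitask = MultitaskWrapper(
    {
        "classification": MetricCollection([BinaryAccuracy(), BinaryF1Score()]),
        "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
    }
)

# flatten=True (the new default): one entry per sub-metric, named "<task>_<metric>".
print(list(multitask.keys()))
# ['classification_BinaryAccuracy', 'classification_BinaryF1Score',
#  'regression_MeanSquaredError', 'regression_MeanAbsoluteError']

# flatten=False: the previous task-level view, yielding the MetricCollections whole.
for name, metric in multitask.items(flatten=False):
    print(name, type(metric).__name__)
# classification MetricCollection
# regression MetricCollection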

tests/integrations/test_lightning.py

Lines changed: 22 additions & 6 deletions
@@ -28,7 +28,7 @@
 from torchmetrics import MetricCollection
 from torchmetrics.aggregation import SumMetric
 from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
-from torchmetrics.regression import MeanSquaredError
+from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
 from torchmetrics.wrappers import MultitaskWrapper
 
 from integrations.helpers import no_warning_call
@@ -366,22 +366,34 @@ def test_task_wrapper_lightning_logging(tmpdir):
     class TestModel(BoringModel):
         def __init__(self) -> None:
             super().__init__()
-            self.metric = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
+            self.multitask = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
+            self.multitask_collection = MultitaskWrapper(
+                {
+                    "classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
+                    "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
+                }
+            )
+
             self.accuracy = BinaryAccuracy()
             self.mse = MeanSquaredError()
 
         def training_step(self, batch, batch_idx):
             preds = torch.rand(10)
             target = torch.rand(10)
-            self.metric(
-                {"classification": preds.round(), "regression": preds},
-                {"classification": target.round(), "regression": target},
+            self.multitask(
+                {"classification": preds, "regression": preds},
+                {"classification": target.round().int(), "regression": target},
+            )
+            self.multitask_collection(
+                {"classification": preds, "regression": preds},
+                {"classification": target.round().int(), "regression": target},
             )
             self.accuracy(preds.round(), target.round())
             self.mse(preds, target)
             self.log("accuracy", self.accuracy, on_epoch=True)
             self.log("mse", self.mse, on_epoch=True)
-            self.log_dict(self.metric, on_epoch=True)
+            self.log_dict(self.multitask, on_epoch=True)
+            self.log_dict(self.multitask_collection, on_epoch=True)
             return self.step(batch)
 
     model = TestModel()
@@ -404,6 +416,10 @@ def training_step(self, batch, batch_idx):
     assert torch.allclose(logged["accuracy_epoch"], logged["classification_epoch"])
     assert torch.allclose(logged["mse_step"], logged["regression_step"])
     assert torch.allclose(logged["mse_epoch"], logged["regression_epoch"])
+    assert "regression_MeanAbsoluteError_epoch" in logged
+    assert "regression_MeanSquaredError_epoch" in logged
+    assert "classification_BinaryAccuracy_epoch" in logged
+    assert "classification_BinaryAveragePrecision_epoch" in logged
 
 
 def test_scriptable(tmpdir):
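Two details worth noting about the assertions above: the `<task>_<metric>` part of each logged key comes from the wrapper's flattened iteration, while the `_step`/`_epoch` suffixes are appended by Lightning because the metrics are logged with `on_epoch=True` inside `training_step`. Outside the test harness, the same pattern looks roughly like this (a hedged sketch, not from the commit: `LitModel` is a hypothetical name, the import path depends on whether `pytorch_lightning` or `lightning` is installed, and loss computation is omitted):

import torch
from pytorch_lightning import LightningModule  # or: from lightning.pytorch import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.wrappers import MultitaskWrapper


class LitModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.multitask = MultitaskWrapper(
            {
                "classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
                "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
            }
        )

    def training_step(self, batch, batch_idx):
        preds, target = torch.rand(10), torch.rand(10)
        # Update every task metric; classification targets must be integers.
        self.multitask(
            {"classification": preds, "regression": preds},
            {"classification": target.round().int(), "regression": target},
        )
        # With the flattened items(), each sub-metric is logged under
        # "<task>_<metric>", e.g. "classification_BinaryAccuracy_epoch".
        self.log_dict(self.multitask, on_epoch=True)
        return None  # a real model would compute and return its loss here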

tests/unittests/wrappers/test_multitask.py

Lines changed: 49 additions & 0 deletions
@@ -209,6 +209,55 @@ def test_nested_multitask_wrapper():
     assert _dict_results_same_as_individual_results(classification_results, regression_results, multitask_results)
 
 
+@pytest.mark.parametrize("method", ["keys", "items", "values"])
+@pytest.mark.parametrize("flatten", [True, False])
+def test_key_value_items_method(method, flatten):
+    """Test the keys, items, and values methods of the MultitaskWrapper."""
+    multitask = MultitaskWrapper(
+        {
+            "classification": MetricCollection([BinaryAccuracy(), BinaryF1Score()]),
+            "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
+        }
+    )
+    if method == "keys":
+        output = list(multitask.keys(flatten=flatten))
+    elif method == "items":
+        output = list(multitask.items(flatten=flatten))
+    elif method == "values":
+        output = list(multitask.values(flatten=flatten))
+
+    if flatten:
+        assert len(output) == 4
+        if method == "keys":
+            assert output == [
+                "classification_BinaryAccuracy",
+                "classification_BinaryF1Score",
+                "regression_MeanSquaredError",
+                "regression_MeanAbsoluteError",
+            ]
+        elif method == "items":
+            assert output == [
+                ("classification_BinaryAccuracy", BinaryAccuracy()),
+                ("classification_BinaryF1Score", BinaryF1Score()),
+                ("regression_MeanSquaredError", MeanSquaredError()),
+                ("regression_MeanAbsoluteError", MeanAbsoluteError()),
+            ]
+        elif method == "values":
+            assert output == [BinaryAccuracy(), BinaryF1Score(), MeanSquaredError(), MeanAbsoluteError()]
+    else:
+        assert len(output) == 2
+        if method == "keys":
+            assert output == ["classification", "regression"]
+        elif method == "items":
+            assert output[0][0] == "classification"
+            assert output[1][0] == "regression"
+            assert isinstance(output[0][1], MetricCollection)
+            assert isinstance(output[1][1], MetricCollection)
+        elif method == "values":
+            assert isinstance(output[0], MetricCollection)
+            assert isinstance(output[1], MetricCollection)
+
+
 def test_clone_with_prefix_and_postfix():
     """Check that the clone method works with prefix and postfix arguments."""
     multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()})
