Fix compositional logging with lightning #1761

Merged: 24 commits from `lightning_logging` into `master` on Jun 15, 2023.
Commits (24):
- `ab17a18` tests (SkafteNicki, May 6, 2023)
- `d75b3ca` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 6, 2023)
- `8fc2fac` Merge branch 'master' into lightning_logging (Borda, May 9, 2023)
- `7d9968e` Merge branch 'master' into lightning_logging (SkafteNicki, May 10, 2023)
- `54e3b40` Merge branch 'master' into lightning_logging (Borda, May 15, 2023)
- `1c5fe5e` Merge branch 'master' into lightning_logging (SkafteNicki, May 22, 2023)
- `8f976cd` tests (SkafteNicki, May 22, 2023)
- `6e73d26` fix implementation (SkafteNicki, May 22, 2023)
- `a071462` Merge branch 'master' into lightning_logging (SkafteNicki, May 22, 2023)
- `2e2b245` changelog (SkafteNicki, May 22, 2023)
- `0dff3b3` Merge branch 'master' into lightning_logging (SkafteNicki, May 22, 2023)
- `0c438d6` Merge branch 'master' into lightning_logging (SkafteNicki, May 23, 2023)
- `fda65e4` Merge branch 'master' into lightning_logging (SkafteNicki, May 23, 2023)
- `780d23f` Merge branch 'master' into lightning_logging (SkafteNicki, May 25, 2023)
- `e09cf42` Merge branch 'master' into lightning_logging (Borda, May 29, 2023)
- `fb1ceff` Merge branch 'master' into lightning_logging (Borda, May 30, 2023)
- `6d54b28` Merge branch 'master' into lightning_logging (SkafteNicki, Jun 5, 2023)
- `71779fc` Merge branch 'master' into lightning_logging (Borda, Jun 5, 2023)
- `eda9e83` Merge branch 'master' into lightning_logging (SkafteNicki, Jun 6, 2023)
- `522af72` Merge branch 'master' into lightning_logging (Borda, Jun 6, 2023)
- `5f4ea42` Merge branch 'master' into lightning_logging (Borda, Jun 12, 2023)
- `0fe99a5` Merge branch 'master' into lightning_logging (Borda, Jun 13, 2023)
- `f720f1e` Merge branch 'master' into lightning_logging (Borda, Jun 13, 2023)
- `f393232` Merge branch 'master' into lightning_logging (mergify[bot], Jun 15, 2023)
CHANGELOG.md (3 additions, 0 deletions):

```diff
@@ -206,6 +206,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed lookup for punkt sources being downloaded in `RougeScore` ([#1789](https://github.com/Lightning-AI/torchmetrics/pull/1789))
 
 
+- Fixed integration with lightning for `CompositionalMetric` ([#1761](https://github.com/Lightning-AI/torchmetrics/pull/1761))
+
+
 - Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808))
```
src/torchmetrics/metric.py (8 additions, 4 deletions):

```diff
@@ -1086,17 +1086,21 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         )
 
         if val_a is None:
-            return None
+            self._forward_cache = None
+            return self._forward_cache
 
         if val_b is None:
             if isinstance(self.metric_b, Metric):
-                return None
+                self._forward_cache = None
+                return self._forward_cache
 
             # Unary op
-            return self.op(val_a)
+            self._forward_cache = self.op(val_a)
+            return self._forward_cache
 
         # Binary op
-        return self.op(val_a, val_b)
+        self._forward_cache = self.op(val_a, val_b)
+        return self._forward_cache
 
     def reset(self) -> None:
         """Redirect the call to the input which the composition was formed from."""
```
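In short, every early `return` in `CompositionalMetric.forward` now writes its result to `self._forward_cache` first; Lightning's logging internals read that attribute for the per-step value of a logged metric object, so compositions previously had nothing to log at step level. A minimal sketch of the patched behavior, assuming only `torch` and `torchmetrics` and the same `SumMetric` composition the test below uses:

```python
import torch
from torchmetrics.aggregation import SumMetric

# Adding two metrics builds a CompositionalMetric whose op is `+`.
compo = SumMetric() + SumMetric()

# forward() evaluates both operands on the batch and combines them; with this
# patch the combined batch value is also stored in `_forward_cache`, the
# attribute Lightning consults when a logged metric is reported on step.
out = compo(torch.tensor(2.0))

assert out == torch.tensor(4.0)      # 2.0 from each operand
assert compo._forward_cache == out   # previously left unset by forward()
```

Before the fix the combined value was returned without being cached, so `self.log(..., on_step=True)` on a composition had no batch value to pick up.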
tests/integrations/test_lightning.py (112 additions, 17 deletions):

```diff
@@ -20,8 +20,10 @@
 
 if module_available("lightning"):
     from lightning import LightningModule, Trainer
+    from lightning.pytorch.loggers import CSVLogger
 else:
     from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.loggers import CSVLogger
 
 from torchmetrics import MetricCollection
 from torchmetrics.aggregation import SumMetric
@@ -180,44 +182,137 @@ def test_metric_lightning_log(tmpdir):
     class TestModel(BoringModel):
         def __init__(self) -> None:
             super().__init__()
-            self.metric_step = SumMetric()
-            self.metric_epoch = SumMetric()
-            self.register_buffer("sum", torch.tensor(0.0))
-            self.outs = []
-
-        def on_train_epoch_start(self):
-            self.sum = torch.tensor(0.0, device=self.sum.device)
+            # initialize one metric for every combination of `on_step` and `on_epoch` and `forward` and `update`
+            self.metric_update = SumMetric()
+            self.metric_update_step = SumMetric()
+            self.metric_update_epoch = SumMetric()
+
+            self.metric_forward = SumMetric()
+            self.metric_forward_step = SumMetric()
+            self.metric_forward_epoch = SumMetric()
+
+            self.compo_update = SumMetric() + SumMetric()
+            self.compo_update_step = SumMetric() + SumMetric()
+            self.compo_update_epoch = SumMetric() + SumMetric()
+
+            self.compo_forward = SumMetric() + SumMetric()
+            self.compo_forward_step = SumMetric() + SumMetric()
+            self.compo_forward_epoch = SumMetric() + SumMetric()
+
+            self.sum = []
 
         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric_step(x.sum())
-            self.sum += x.sum()
-            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
-            self.outs.append(x)
-            return self.step(x)
+            s = x.sum()
 
-        def on_train_epoch_end(self):
-            self.log("sum_epoch", self.metric_epoch(torch.stack(self.outs)))
-            self.outs = []
+            for metric in [self.metric_update, self.metric_update_step, self.metric_update_epoch]:
+                metric.update(s)
+            for metric in [self.metric_forward, self.metric_forward_step, self.metric_forward_epoch]:
+                _ = metric(s)
+            for metric in [self.compo_update, self.compo_update_step, self.compo_update_epoch]:
+                metric.update(s)
+            for metric in [self.compo_forward, self.compo_forward_step, self.compo_forward_epoch]:
+                _ = metric(s)
+
+            self.sum.append(s)
+
+            self.log("metric_update", self.metric_update)
+            self.log("metric_update_step", self.metric_update_step, on_epoch=False, on_step=True)
+            self.log("metric_update_epoch", self.metric_update_epoch, on_epoch=True, on_step=False)
+
+            self.log("metric_forward", self.metric_forward)
+            self.log("metric_forward_step", self.metric_forward_step, on_epoch=False, on_step=True)
+            self.log("metric_forward_epoch", self.metric_forward_epoch, on_epoch=True, on_step=False)
+
+            self.log("compo_update", self.compo_update)
+            self.log("compo_update_step", self.compo_update_step, on_epoch=False, on_step=True)
+            self.log("compo_update_epoch", self.compo_update_epoch, on_epoch=True, on_step=False)
+
+            self.log("compo_forward", self.compo_forward)
+            self.log("compo_forward_step", self.compo_forward_step, on_epoch=False, on_step=True)
+            self.log("compo_forward_epoch", self.compo_forward_epoch, on_epoch=True, on_step=False)
+
+            return self.step(x)
 
     model = TestModel()
 
+    logger = CSVLogger("tmpdir/logs")
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=2,
         limit_val_batches=0,
         max_epochs=2,
         log_every_n_steps=1,
+        logger=logger,
     )
     with no_warning_call(
         UserWarning,
         match="Torchmetrics v0.9 introduced a new argument class property called.*",
     ):
         trainer.fit(model)
 
-    logged = trainer.logged_metrics
-    assert torch.allclose(tensor(logged["sum_step"]), model.sum, atol=2e-4)
-    assert torch.allclose(tensor(logged["sum_epoch"]), model.sum, atol=2e-4)
+    logged_metrics = logger._experiment.metrics
+
+    epoch_0_step_0 = logged_metrics[0]
+    assert "metric_forward" in epoch_0_step_0
+    assert epoch_0_step_0["metric_forward"] == model.sum[0]
+    assert "metric_forward_step" in epoch_0_step_0
+    assert epoch_0_step_0["metric_forward_step"] == model.sum[0]
+    assert "compo_forward" in epoch_0_step_0
+    assert epoch_0_step_0["compo_forward"] == 2 * model.sum[0]
+    assert "compo_forward_step" in epoch_0_step_0
+    assert epoch_0_step_0["compo_forward_step"] == 2 * model.sum[0]
+
+    epoch_0_step_1 = logged_metrics[1]
+    assert "metric_forward" in epoch_0_step_1
+    assert epoch_0_step_1["metric_forward"] == model.sum[1]
+    assert "metric_forward_step" in epoch_0_step_1
+    assert epoch_0_step_1["metric_forward_step"] == model.sum[1]
+    assert "compo_forward" in epoch_0_step_1
+    assert epoch_0_step_1["compo_forward"] == 2 * model.sum[1]
+    assert "compo_forward_step" in epoch_0_step_1
+    assert epoch_0_step_1["compo_forward_step"] == 2 * model.sum[1]
+
+    epoch_0 = logged_metrics[2]
+    assert "metric_update_epoch" in epoch_0
+    assert epoch_0["metric_update_epoch"] == sum([model.sum[0], model.sum[1]])
+    assert "metric_forward_epoch" in epoch_0
+    assert epoch_0["metric_forward_epoch"] == sum([model.sum[0], model.sum[1]])
+    assert "compo_update_epoch" in epoch_0
+    assert epoch_0["compo_update_epoch"] == 2 * sum([model.sum[0], model.sum[1]])
+    assert "compo_forward_epoch" in epoch_0
+    assert epoch_0["compo_forward_epoch"] == 2 * sum([model.sum[0], model.sum[1]])
+
+    epoch_1_step_0 = logged_metrics[3]
+    assert "metric_forward" in epoch_1_step_0
+    assert epoch_1_step_0["metric_forward"] == model.sum[2]
+    assert "metric_forward_step" in epoch_1_step_0
+    assert epoch_1_step_0["metric_forward_step"] == model.sum[2]
+    assert "compo_forward" in epoch_1_step_0
+    assert epoch_1_step_0["compo_forward"] == 2 * model.sum[2]
+    assert "compo_forward_step" in epoch_1_step_0
+    assert epoch_1_step_0["compo_forward_step"] == 2 * model.sum[2]
+
+    epoch_1_step_1 = logged_metrics[4]
+    assert "metric_forward" in epoch_1_step_1
+    assert epoch_1_step_1["metric_forward"] == model.sum[3]
+    assert "metric_forward_step" in epoch_1_step_1
+    assert epoch_1_step_1["metric_forward_step"] == model.sum[3]
+    assert "compo_forward" in epoch_1_step_1
+    assert epoch_1_step_1["compo_forward"] == 2 * model.sum[3]
+    assert "compo_forward_step" in epoch_1_step_1
+    assert epoch_1_step_1["compo_forward_step"] == 2 * model.sum[3]
+
+    epoch_1 = logged_metrics[5]
+    assert "metric_update_epoch" in epoch_1
+    assert epoch_1["metric_update_epoch"] == sum([model.sum[2], model.sum[3]])
+    assert "metric_forward_epoch" in epoch_1
+    assert epoch_1["metric_forward_epoch"] == sum([model.sum[2], model.sum[3]])
+    assert "compo_update_epoch" in epoch_1
+    assert epoch_1["compo_update_epoch"] == 2 * sum([model.sum[2], model.sum[3]])
+    assert "compo_forward_epoch" in epoch_1
+    assert epoch_1["compo_forward_epoch"] == 2 * sum([model.sum[2], model.sum[3]])
 
 
 def test_metric_collection_lightning_log(tmpdir):
```
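To round off, a condensed sketch of the end-user pattern this PR repairs: logging a composed metric straight from a `LightningModule`. Only the compositional metric and the `self.log` call mirror the test; the model, data, and optimizer are hypothetical filler added for a runnable example:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

try:  # same dual import the test file uses
    from lightning import LightningModule, Trainer
except ImportError:
    from pytorch_lightning import LightningModule, Trainer

from torchmetrics.aggregation import SumMetric


class CompoLoggingModel(LightningModule):
    """Hypothetical minimal module; only the compositional-metric logging is from the PR."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.compo = SumMetric() + SumMetric()  # a CompositionalMetric

    def training_step(self, batch, batch_idx):
        (x,) = batch
        _ = self.compo(x.sum())  # forward now populates `_forward_cache`
        # Step-level logging of the composition works after the fix.
        self.log("compo_step", self.compo, on_step=True, on_epoch=False)
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
    Trainer(max_epochs=1, log_every_n_steps=1).fit(CompoLoggingModel(), data)
```

Before this patch the same `self.log` call would silently log nothing useful per step, since the composition never exposed a batch value for Lightning to read.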