From 6f890344d62213a00d55a793ec6af5b7272141eb Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Wed, 7 Feb 2024 17:23:32 +0100
Subject: [PATCH] Fix `FeatureShare` wrapper on GPU (#2348)

---
 CHANGELOG.md                                   |  2 +-
 src/torchmetrics/wrappers/feature_share.py     |  3 ++-
 tests/unittests/wrappers/test_feature_share.py | 17 +++++++++++++++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9b74406f13f..ad35113a0f6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))
 
 
 ---

diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py
index 14a4c051c64..b1fe0451534 100644
--- a/src/torchmetrics/wrappers/feature_share.py
+++ b/src/torchmetrics/wrappers/feature_share.py
@@ -34,7 +34,8 @@ class NetworkCache(Module):
     def __init__(self, network: Module, max_size: int = 100) -> None:
         super().__init__()
         self.max_size = max_size
-        self.network = lru_cache(maxsize=self.max_size)(network)
+        self.network = network
+        self.network.forward = lru_cache(maxsize=self.max_size)(network.forward)
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Call the network with the given arguments."""

diff --git a/tests/unittests/wrappers/test_feature_share.py b/tests/unittests/wrappers/test_feature_share.py
index 364b94f8bb5..adaeba3fd53 100644
--- a/tests/unittests/wrappers/test_feature_share.py
+++ b/tests/unittests/wrappers/test_feature_share.py
@@ -95,14 +95,14 @@ def test_memory():
 
     fid = FrechetInceptionDistance(feature=64).cuda()
     inception = InceptionScore(feature=64).cuda()
-    kid = KernelInceptionDistance(feature=64).cuda()
+    kid = KernelInceptionDistance(feature=64, subset_size=5).cuda()
 
     memory_before_fs = torch.cuda.memory_allocated()
     assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."
 
     torch.cuda.empty_cache()
 
-    FeatureShare([fid, inception, kid]).cuda()
+    feature_share = FeatureShare([fid, inception, kid]).cuda()
 
     memory_after_fs = torch.cuda.memory_allocated()
     assert (
@@ -112,6 +112,19 @@
         memory_after_fs < memory_before_fs
     ), "The memory usage should be higher after initializing the feature share wrapper."
 
+    img1 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")
+    img2 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")
+
+    feature_share.update(img1, real=True)
+    feature_share.update(img2, real=False)
+    res = feature_share.compute()
+
+    assert "cuda" in str(res["FrechetInceptionDistance"].device)
+    assert "cuda" in str(res["InceptionScore"][0].device)
+    assert "cuda" in str(res["InceptionScore"][1].device)
+    assert "cuda" in str(res["KernelInceptionDistance"][0].device)
+    assert "cuda" in str(res["KernelInceptionDistance"][1].device)
+
 
 def test_same_result_as_individual():
     """Test that the feature share wrapper gives the same result as the individual metrics."""