diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6c78c535f9..6df503fc56c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ---
 
+## [UnReleased] - 2024-MM-DD
+
+### Added
+
+-
+
+
+### Changed
+
+-
+
+
+### Deprecated
+
+-
+
+
+### Fixed
+
+- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))
+
+
 ## [1.3.0] - 2024-01-10
 
 ### Added
diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py
index 14a4c051c64..b1fe0451534 100644
--- a/src/torchmetrics/wrappers/feature_share.py
+++ b/src/torchmetrics/wrappers/feature_share.py
@@ -34,7 +34,8 @@ class NetworkCache(Module):
     def __init__(self, network: Module, max_size: int = 100) -> None:
         super().__init__()
         self.max_size = max_size
-        self.network = lru_cache(maxsize=self.max_size)(network)
+        self.network = network
+        self.network.forward = lru_cache(maxsize=self.max_size)(network.forward)
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Call the network with the given arguments."""
diff --git a/tests/unittests/wrappers/test_feature_share.py b/tests/unittests/wrappers/test_feature_share.py
index 364b94f8bb5..adaeba3fd53 100644
--- a/tests/unittests/wrappers/test_feature_share.py
+++ b/tests/unittests/wrappers/test_feature_share.py
@@ -95,14 +95,14 @@ def test_memory():
 
     fid = FrechetInceptionDistance(feature=64).cuda()
     inception = InceptionScore(feature=64).cuda()
-    kid = KernelInceptionDistance(feature=64).cuda()
+    kid = KernelInceptionDistance(feature=64, subset_size=5).cuda()
 
     memory_before_fs = torch.cuda.memory_allocated()
     assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."
 
     torch.cuda.empty_cache()
 
-    FeatureShare([fid, inception, kid]).cuda()
+    feature_share = FeatureShare([fid, inception, kid]).cuda()
     memory_after_fs = torch.cuda.memory_allocated()
 
     assert (
@@ -112,6 +112,19 @@ def test_memory():
         memory_after_fs < memory_before_fs
     ), "The memory usage should be higher after initializing the feature share wrapper."
 
+    img1 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")
+    img2 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")
+
+    feature_share.update(img1, real=True)
+    feature_share.update(img2, real=False)
+    res = feature_share.compute()
+
+    assert "cuda" in str(res["FrechetInceptionDistance"].device)
+    assert "cuda" in str(res["InceptionScore"][0].device)
+    assert "cuda" in str(res["InceptionScore"][1].device)
+    assert "cuda" in str(res["KernelInceptionDistance"][0].device)
+    assert "cuda" in str(res["KernelInceptionDistance"][1].device)
+
 
 def test_same_result_as_individual():
     """Test that the feature share wrapper gives the same result as the individual metrics."""
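
Note (not part of the patch): a minimal, self-contained sketch of the caching pattern the fix adopts. `DummyNet` and `CachedNet` are hypothetical names used only for illustration; nothing beyond `torch` and `functools.lru_cache` is assumed. The point is that assigning the network as a plain `Module` attribute keeps it registered as a submodule, so `.cuda()`/`.to()` still move its weights, while only the bound `forward` method is memoized.

```python
from functools import lru_cache

import torch
from torch.nn import Linear, Module


class DummyNet(Module):
    """Hypothetical feature extractor, stand-in for the shared backbone."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


class CachedNet(Module):
    """Sketch of the pattern used by the patched ``NetworkCache``."""

    def __init__(self, network: Module, max_size: int = 100) -> None:
        super().__init__()
        # Assigning the Module itself registers it as a submodule, so
        # wrapper.to(device) / wrapper.cuda() also move its parameters.
        self.network = network
        # Memoize only the bound forward method. The previous approach,
        # lru_cache(...)(network), replaced the attribute with a plain
        # functools wrapper that PyTorch could not track on device moves.
        self.network.forward = lru_cache(maxsize=max_size)(network.forward)

    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    wrapper = CachedNet(DummyNet()).to(device)
    # The weights follow the wrapper because `network` is a registered submodule.
    print(next(wrapper.parameters()).device)
    x = torch.randn(1, 4, device=device)
    out1 = wrapper(x)
    out2 = wrapper(x)
    # Second call with the same tensor object returns the cached result.
    assert out1 is out2
```

The new assertions in `test_memory` exercise exactly this behaviour end to end: the wrapper is moved with `.cuda()`, then `update`/`compute` are run and the returned metric values are checked to live on the CUDA device.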