Skip to content

Commit

Permalink
Fix FeatureShare wrapper on GPU (#2348)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Feb 7, 2024
1 parent d5719d0 commit 6f89034
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))


---
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/wrappers/feature_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class NetworkCache(Module):
def __init__(self, network: Module, max_size: int = 100) -> None:
super().__init__()
self.max_size = max_size
self.network = lru_cache(maxsize=self.max_size)(network)
self.network = network
self.network.forward = lru_cache(maxsize=self.max_size)(network.forward)

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Call the network with the given arguments."""
Expand Down
17 changes: 15 additions & 2 deletions tests/unittests/wrappers/test_feature_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def test_memory():

fid = FrechetInceptionDistance(feature=64).cuda()
inception = InceptionScore(feature=64).cuda()
kid = KernelInceptionDistance(feature=64).cuda()
kid = KernelInceptionDistance(feature=64, subset_size=5).cuda()

memory_before_fs = torch.cuda.memory_allocated()
assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."

torch.cuda.empty_cache()

FeatureShare([fid, inception, kid]).cuda()
feature_share = FeatureShare([fid, inception, kid]).cuda()
memory_after_fs = torch.cuda.memory_allocated()

assert (
Expand All @@ -112,6 +112,19 @@ def test_memory():
memory_after_fs < memory_before_fs
), "The memory usage should be higher after initializing the feature share wrapper."

img1 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")
img2 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda")

feature_share.update(img1, real=True)
feature_share.update(img2, real=False)
res = feature_share.compute()

assert "cuda" in str(res["FrechetInceptionDistance"].device)
assert "cuda" in str(res["InceptionScore"][0].device)
assert "cuda" in str(res["InceptionScore"][1].device)
assert "cuda" in str(res["KernelInceptionDistance"][0].device)
assert "cuda" in str(res["KernelInceptionDistance"][1].device)


def test_same_result_as_individual():
"""Test that the feature share wrapper gives the same result as the individual metrics."""
Expand Down

0 comments on commit 6f89034

Please sign in to comment.