From 8ef2a89973127fd839597af81c62f98f790be649 Mon Sep 17 00:00:00 2001
From: Kyle Dorman
Date: Sat, 25 Nov 2023 12:44:08 -0800
Subject: [PATCH] Use arange and repeat for deterministic bincount (#2184)

* Use meshgrid for deterministic bincount

* Update src/torchmetrics/utilities/data.py

Use arange instead of meshgrid

* Update data.py

Update _bincount doc

* chlog

* size

* Update CHANGELOG.md

update changelog

* Update data.py

update comment

* improve text

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 CHANGELOG.md                       |  3 +++
 src/torchmetrics/utilities/data.py | 18 ++++++++----------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae0db13a079..4b637a56eb6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
 
+- Use arange and repeat for deterministic bincount ([#2184](https://github.com/Lightning-AI/torchmetrics/pull/2184))
+
+
 ### Deprecated
 
 - Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141))

diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py
index 8e818a144f7..09c28b9fc16 100644
--- a/src/torchmetrics/utilities/data.py
+++ b/src/torchmetrics/utilities/data.py
@@ -169,12 +169,10 @@ def _squeeze_if_scalar(data: Any) -> Any:
 def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """Implement custom bincount.
 
-    PyTorch currently does not support ``torch.bincount`` for:
-
-        - deterministic mode on GPU.
-        - MPS devices
-
-    This implementation fallback to a for-loop counting occurrences in that case.
+    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU, when running on
+    MPS devices, or when running on an XLA device. This implementation therefore falls back to a combination of
+    ``torch.arange`` and ``torch.eq`` in these scenarios. A small performance hit and higher memory consumption can be
+    expected, as a ``[batch_size, minlength]`` tensor needs to be initialized compared to native ``torch.bincount``.
 
     Args:
         x: tensor to count
@@ -191,11 +189,11 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """
     if minlength is None:
         minlength = len(torch.unique(x))
+
     if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
-        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
-        for i in range(minlength):
-            output[i] = (x == i).sum()
-        return output
+        mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
+        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
+
     return torch.bincount(x, minlength=minlength)
 
 
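A minimal standalone sketch of the fallback introduced by this patch, not part of the diff itself: the helper name `deterministic_bincount` is hypothetical, but its body mirrors the two lines added to `_bincount`, and the example checks it against native `torch.bincount`.

    import torch

    def deterministic_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
        # Build a [len(x), minlength] grid of candidate values with arange/repeat,
        # compare it element-wise against the input column, and count matches per
        # value. This avoids the non-deterministic CUDA kernel behind torch.bincount.
        mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

    x = torch.tensor([0, 1, 1, 3, 3, 3])
    print(deterministic_bincount(x, minlength=4))  # tensor([1, 2, 0, 3])
    print(torch.bincount(x, minlength=4))          # tensor([1, 2, 0, 3])

Both calls should agree; the fallback trades the memory for the `[len(x), minlength]` comparison grid against a fully deterministic sum reduction.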