Skip to content

Commit 5b8c55e

Browse files
committed
BUG: torch: fix count_nonzero(... keepdims=True)
1 parent fc8777f commit 5b8c55e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,22 @@ def diff(
521521
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
522522

523523

524-
# torch uses `dim` instead of `axis`
524+
# torch uses `dim` instead of `axis`, does not have keepdims
525525
def count_nonzero(
526526
x: array,
527527
/,
528528
*,
529529
axis: Optional[Union[int, Tuple[int, ...]]] = None,
530530
keepdims: bool = False,
531531
) -> array:
532-
return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
532+
result = torch.count_nonzero(x, dim=axis)
533+
if keepdims:
534+
if axis is not None:
535+
return result.unsqueeze(axis)
536+
return _axis_none_keepdims(result, x.ndim, keepdims)
537+
else:
538+
return result
539+
533540

534541

535542
def where(condition: array, x1: array, x2: array, /) -> array:

0 commit comments

Comments
 (0)