Skip to content

Commit

Permalink
[Array API] Add linalg.diagonal (pytorch#70599)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#70599

This PR adds `linalg.diagonal` following the Array API:
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-diagonal-x-axis1-0-axis2-1-offset-0

Fixes pytorch#62813

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33760506

Pulled By: mruberry

fbshipit-source-id: e32c3490321d8c3f31b3bb538bc1f72b39bd2854
(cherry picked from commit 44f41f8)
  • Loading branch information
lezcano authored and pytorchmergebot committed Jan 26, 2022
1 parent fe277b8 commit 108b37d
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 4 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,11 @@ Tensor& linalg_matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor
return at::native::matmul_out(tensor1, tensor2, result);
}

// torch.linalg.diagonal, alias for torch.diagonal with dim1=-2, dim2=-1 as defaults
Tensor linalg_diagonal(const Tensor& A, int64_t offset, int64_t dim1, int64_t dim2) {
return A.diagonal(offset, dim1, dim2);
}

// helper methods for matrix_exp
namespace {

Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,10 @@
dispatch:
CompositeExplicitAutograd: diagonal

- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
python_module: linalg
variants: function

- func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
variants: function, method

Expand Down
1 change: 1 addition & 0 deletions docs/source/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Matrix Properties
norm
vector_norm
matrix_norm
diagonal
det
slogdet
cond
Expand Down
6 changes: 6 additions & 0 deletions torch/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,12 @@
Alias for :func:`torch.matmul`
""")

diagonal = _add_docstr(_linalg.linalg_diagonal, r"""
linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor
Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`.
""")

multi_dot = _add_docstr(_linalg.linalg_multi_dot, r"""
linalg.multi_dot(tensors, *, out=None)
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.diagflat: lambda input, offset=0: -1,
torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
torch.digamma: lambda input, out=None: -1,
torch.dist: lambda input, other, p=2: -1,
Expand Down
13 changes: 9 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6022,11 +6022,13 @@ def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **k
# Shapes for 3D Tensors
shapes_3d = ((M, M, M),)

args_2d = ((), (2,), (-2,), (1,))
args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))
kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1))
kwargs_3d = (dict(offset=1, dim1=1, dim2=2),
dict(offset=2, dim1=0, dim2=1),
dict(offset=-2, dim1=0, dim2=1))

for shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
yield SampleInput(make_arg(shape), args=arg)
for shape, kwarg in chain(product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)):
yield SampleInput(make_arg(shape), kwargs=kwarg)


def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
Expand Down Expand Up @@ -9309,6 +9311,9 @@ def ref_pairwise_distance(input1, input2):
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_diagonal_diag_embed),
OpInfo('diagonal',
# They are not strictly aliases as they have diverging defaults, but we can see them as aliases for testing purposes
# If we add tests that test the function against the alias, make linalg.diagonal into its own OpInfo
aliases=('linalg.diagonal',),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
supports_forward_ad=True,
Expand Down

0 comments on commit 108b37d

Please sign in to comment.