Skip to content

Commit 9dc8f25

Browse files
lezcanopytorchmergebot
authored andcommitted
Update torch.lu_unpack docs
As per title Pull Request resolved: pytorch#73803 Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry
1 parent fc5b4a5 commit 9dc8f25

File tree

2 files changed

+29
-41
lines changed

2 files changed

+29
-41
lines changed

torch/_torch_docs.py

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5754,63 +5754,48 @@ def merge_dicts(*dicts):
57545754
add_docstr(torch.lu_unpack, r"""
57555755
lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor)
57565756
5757-
Unpacks the data and pivots from a LU factorization of a tensor into tensors ``L`` and ``U`` and a permutation tensor ``P``
5758-
such that ``LU_data, LU_pivots = (P @ L @ U).lu()``.
5757+
Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices.
57595758
5760-
Returns a tuple of tensors as ``(the P tensor (permutation matrix), the L tensor, the U tensor)``.
5759+
.. seealso::
57615760
5762-
.. note:: ``P.dtype == LU_data.dtype`` and ``P.dtype`` is not an integer type so that matrix products with ``P``
5763-
are possible without casting it to a floating type.
5761+
:func:`~linalg.lu` returns the matrices from the LU decomposition. It is more efficient than
5762+
doing :func:`~linalg.lu_factor` and then :func:`~linalg.lu_unpack`.
57645763
57655764
Args:
57665765
LU_data (Tensor): the packed LU factorization data
57675766
LU_pivots (Tensor): the packed LU factorization pivots
57685767
unpack_data (bool): flag indicating if the data should be unpacked.
5769-
If ``False``, then the returned ``L`` and ``U`` are ``None``.
5768+
If ``False``, then the returned ``L`` and ``U`` are empty tensors.
57705769
Default: ``True``
57715770
unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``.
5772-
If ``False``, then the returned ``P`` is ``None``.
5771+
If ``False``, then the returned ``P`` is an empty tensor.
57735772
Default: ``True``
5774-
out (tuple, optional): a tuple of three tensors to use for the outputs ``(P, L, U)``.
5773+
5774+
Keyword args:
5775+
out (tuple, optional): output tuple of three tensors. Ignored if `None`.
5776+
5777+
Returns:
5778+
A namedtuple ``(P, L, U)``
57755779
57765780
Examples::
57775781
57785782
>>> A = torch.randn(2, 3, 3)
5779-
>>> A_LU, pivots = A.lu()
5780-
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
5781-
>>>
5782-
>>> # can recover A from factorization
5783-
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
5783+
>>> LU, pivots = torch.linalg.lu_factor(A)
5784+
>>> P, L, U = torch.lu_unpack(LU, pivots)
5785+
>>> # We can recover A from the factorization
5786+
>>> A_ = P @ L @ U
5787+
>>> torch.allclose(A, A_)
5788+
True
57845789
57855790
>>> # LU factorization of a rectangular matrix:
57865791
>>> A = torch.randn(2, 3, 2)
5787-
>>> A_LU, pivots = A.lu()
5788-
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
5789-
>>> P
5790-
tensor([[[1., 0., 0.],
5791-
[0., 1., 0.],
5792-
[0., 0., 1.]],
5793-
5794-
[[0., 0., 1.],
5795-
[0., 1., 0.],
5796-
[1., 0., 0.]]])
5797-
>>> A_L
5798-
tensor([[[ 1.0000, 0.0000],
5799-
[ 0.4763, 1.0000],
5800-
[ 0.3683, 0.1135]],
5801-
5802-
[[ 1.0000, 0.0000],
5803-
[ 0.2957, 1.0000],
5804-
[-0.9668, -0.3335]]])
5805-
>>> A_U
5806-
tensor([[[ 2.1962, 1.0881],
5807-
[ 0.0000, -0.8681]],
5808-
5809-
[[-1.0947, 0.3736],
5810-
[ 0.0000, 0.5718]]])
5811-
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
5812-
>>> torch.norm(A_ - A)
5813-
tensor(2.9802e-08)
5792+
>>> LU, pivots = torch.linalg.lu_factor(A)
5793+
>>> P, L, U = torch.lu_unpack(LU, pivots)
5794+
>>> # P, L, U are the same as returned by linalg.lu
5795+
>>> P_, L_, U_ = torch.linalg.lu(A)
5796+
>>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_)
5797+
True
5798+
58145799
""".format(**common_args))
58155800

58165801
add_docstr(torch.less, r"""

torch/linalg/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2231,8 +2231,11 @@
22312231
:func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this
22322232
function provided the input matrix was square and invertible.
22332233
2234+
:func:`torch.lu_unpack` unpacks the tensors returned by :func:`~lu_factor` into the three
2235+
matrices `P, L, U` that form the decomposition.
2236+
22342237
:func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly
2235-
non-square matrix.
2238+
non-square matrix. It is a composition of :func:`~lu_factor` and :func:`torch.lu_unpack`.
22362239
22372240
:func:`torch.linalg.solve` solves a system of linear equations. It is a composition
22382241
of :func:`~lu_factor` and :func:`~lu_solve`.

0 commit comments

Comments
 (0)