@@ -5754,63 +5754,48 @@ def merge_dicts(*dicts):
 add_docstr(torch.lu_unpack, r"""
 lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor)
 
-Unpacks the data and pivots from a LU factorization of a tensor into tensors ``L`` and ``U`` and a permutation tensor ``P``
-such that ``LU_data, LU_pivots = (P @ L @ U).lu()``.
+Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices.
 
-Returns a tuple of tensors as ``(the P tensor (permutation matrix), the L tensor, the U tensor)``.
+.. seealso::
 
-.. note:: ``P.dtype == LU_data.dtype`` and ``P.dtype`` is not an integer type so that matrix products with ``P``
-          are possible without casting it to a floating type.
+    :func:`~linalg.lu` returns the matrices from the LU decomposition. It is more efficient than
+    doing :func:`~linalg.lu_factor` and then :func:`~linalg.lu_unpack`.
 
 Args:
     LU_data (Tensor): the packed LU factorization data
     LU_pivots (Tensor): the packed LU factorization pivots
     unpack_data (bool): flag indicating if the data should be unpacked.
-                        If ``False``, then the returned ``L`` and ``U`` are ``None``.
+                        If ``False``, then the returned ``L`` and ``U`` are empty tensors.
                         Default: ``True``
     unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``.
-                          If ``False``, then the returned ``P`` is ``None``.
+                          If ``False``, then the returned ``P`` is an empty tensor.
                           Default: ``True``
-    out (tuple, optional): a tuple of three tensors to use for the outputs ``(P, L, U)``.
+
+Keyword args:
+    out (tuple, optional): output tuple of three tensors. Ignored if `None`.
+
+Returns:
+    A namedtuple ``(P, L, U)``
 
 Examples::
 
     >>> A = torch.randn(2, 3, 3)
-    >>> A_LU, pivots = A.lu()
-    >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
-    >>>
-    >>> # can recover A from factorization
-    >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+    >>> LU, pivots = torch.linalg.lu_factor(A)
+    >>> P, L, U = torch.lu_unpack(LU, pivots)
+    >>> # We can recover A from the factorization
+    >>> A_ = P @ L @ U
+    >>> torch.allclose(A, A_)
+    True
 
     >>> # LU factorization of a rectangular matrix:
     >>> A = torch.randn(2, 3, 2)
-    >>> A_LU, pivots = A.lu()
-    >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
-    >>> P
-    tensor([[[1., 0., 0.],
-             [0., 1., 0.],
-             [0., 0., 1.]],
-
-            [[0., 0., 1.],
-             [0., 1., 0.],
-             [1., 0., 0.]]])
-    >>> A_L
-    tensor([[[ 1.0000,  0.0000],
-             [ 0.4763,  1.0000],
-             [ 0.3683,  0.1135]],
-
-            [[ 1.0000,  0.0000],
-             [ 0.2957,  1.0000],
-             [-0.9668, -0.3335]]])
-    >>> A_U
-    tensor([[[ 2.1962,  1.0881],
-             [ 0.0000, -0.8681]],
-
-            [[-1.0947,  0.3736],
-             [ 0.0000,  0.5718]]])
-    >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
-    >>> torch.norm(A_ - A)
-    tensor(2.9802e-08)
+    >>> LU, pivots = torch.linalg.lu_factor(A)
+    >>> P, L, U = torch.lu_unpack(LU, pivots)
+    >>> # P, L, U are the same as returned by linalg.lu
+    >>> P_, L_, U_ = torch.linalg.lu(A)
+    >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_)
+    True
+
 """.format(**common_args))
 
 add_docstr(torch.less, r"""
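One behavioral detail the revised docstring states but the new examples do not exercise: with ``unpack_data=False`` or ``unpack_pivots=False`` the corresponding outputs are empty tensors rather than ``None``. A minimal doctest-style sketch of that behavior, assuming the empty-tensor semantics described in the updated text (this snippet is illustrative and not part of the commit):

    >>> import torch
    >>> A = torch.randn(2, 3, 3)
    >>> LU, pivots = torch.linalg.lu_factor(A)
    >>> # unpack_data=False: L and U come back as empty tensors
    >>> P, L, U = torch.lu_unpack(LU, pivots, unpack_data=False)
    >>> L.numel() == 0 and U.numel() == 0
    True
    >>> # unpack_pivots=False: P comes back as an empty tensor
    >>> P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=False)
    >>> P.numel() == 0
    True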