
Commit a5bbfd9

lezcano authored and pytorchmergebot committed
Deprecate torch.lu
**BC-breaking note**: This PR deprecates `torch.lu` in favor of `torch.linalg.lu_factor`. An upgrade guide is added to the documentation for `torch.lu`. Note this PR DOES NOT remove `torch.lu`.

Pull Request resolved: pytorch#73804
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
1 parent 9dc8f25 commit a5bbfd9
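In practical terms, the migration this commit asks for looks like the following minimal sketch (`A` is a placeholder for any batch of square matrices, not a name from the PR):

import torch

A = torch.randn(2, 3, 3)  # any (*, n, n) batch of square matrices

# Deprecated spellings (now emit a warning, slated for removal):
LU, pivots = torch.lu(A)
LU, pivots, info = torch.lu(A, get_infos=True)

# Replacements:
LU, pivots = torch.linalg.lu_factor(A)
LU, pivots, info = torch.linalg.lu_factor_ex(A)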

File tree

10 files changed: +58 −42 lines

aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
test/mobile/model_test/math_ops.py
test/test_jit.py
test/test_linalg.py
torch/_torch_docs.py
torch/backends/cuda/__init__.py
torch/functional.py
torch/linalg/__init__.py
torch/testing/_internal/common_methods_invocations.py

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 13 additions & 2 deletions

@@ -1470,7 +1470,7 @@ std::tuple<Tensor,Tensor> solve(const Tensor& self, const Tensor& A) {
     "torch.solve is deprecated in favor of torch.linalg.solve",
     "and will be removed in a future PyTorch release.\n",
     "torch.linalg.solve has its arguments reversed and does not return the LU factorization.\n",
-    "To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.\n",
+    "To get the LU factorization see torch.linalg.lu_factor, which can be used with torch.lu_solve or torch.lu_unpack.\n",
     "X = torch.solve(B, A).solution\n",
     "should be replaced with\n",
     "X = torch.linalg.solve(A, B)"
@@ -1489,7 +1489,7 @@ std::tuple<Tensor&,Tensor&> solve_out(const Tensor& self, const Tensor& A, Tenso
     "torch.solve is deprecated in favor of torch.linalg.solve",
     "and will be removed in a future PyTorch release.\n",
     "torch.linalg.solve has its arguments reversed and does not return the LU factorization.\n",
-    "To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.\n",
+    "To get the LU factorization see torch.linalg.lu_factor, which can be used with torch.lu_solve or torch.lu_unpack.\n",
     "X = torch.solve(B, A).solution\n",
     "should be replaced with\n",
     "X = torch.linalg.solve(A, B)"
@@ -2256,6 +2256,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {

 // TODO Deprecate this function in favour of linalg_lu_factor_ex
 std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+  TORCH_WARN_ONCE(
+    "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
+    "removed in a future PyTorch release.\n",
+    "LU, pivots = torch.lu(A, compute_pivots)\n",
+    "should be replaced with\n",
+    "LU, pivots = torch.linalg.lu_factor(A, compute_pivots)\n",
+    "and\n",
+    "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
+    "should be replaced with\n",
+    "LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)"
+  );
   return at::linalg_lu_factor_ex(self, compute_pivots, false);
 }
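Since `_lu_with_info` now forwards to `at::linalg_lu_factor_ex`, the old `get_infos=True` path maps onto the `info` tensor of the new API. A short sketch of checking it from Python (the batch shape is illustrative):

import torch

A = torch.randn(4, 5, 5)
LU, pivots, info = torch.linalg.lu_factor_ex(A)
# info follows the LAPACK getrf convention: 0 on success; a value i > 0
# means U[i-1, i-1] is exactly zero, i.e. that matrix is singular.
if (info != 0).any():
    print("some matrices in the batch could not be factorized")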

aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion

@@ -2092,7 +2092,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv
 #if !AT_MAGMA_ENABLED()
   TORCH_CHECK(
       false,
-      "Calling torch.lu on a CUDA tensor requires compiling ",
+      "Calling torch.linalg.lu_factor on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. Please rebuild with MAGMA.");
 #else
   auto input_data = input.data_ptr<scalar_t>();
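This error only fires in builds compiled without MAGMA. A hedged sketch of guarding the CUDA path at runtime, using `torch.cuda.has_magma` to query build-time MAGMA support:

import torch

if torch.cuda.is_available() and torch.cuda.has_magma:
    A = torch.randn(4, 4, device="cuda")
    LU, pivots = torch.linalg.lu_factor(A)  # safe: MAGMA is compiled in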

test/mobile/model_test/math_ops.py

Lines changed: 3 additions & 3 deletions

@@ -441,9 +441,9 @@ def blas_lapack_ops(self):
             # torch.logdet(m),
             # torch.slogdet(m),
             # torch.lstsq(m, m),
-            # torch.lu(m),
-            # torch.lu_solve(m, *torch.lu(m)),
-            # torch.lu_unpack(*torch.lu(m)),
+            # torch.linalg.lu_factor(m),
+            # torch.lu_solve(m, *torch.linalg.lu_factor(m)),
+            # torch.lu_unpack(*torch.linalg.lu_factor(m)),
             torch.matmul(m, m),
             torch.matrix_power(m, 2),
             # torch.matrix_rank(m),

test/test_jit.py

Lines changed: 1 addition & 13 deletions

@@ -9144,20 +9144,8 @@ def istft(input, n_fft):
         inps2 = (stft(*inps), inps[1])
         self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

-        def lu(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return torch.lu(x)
-
-        self.checkScript(lu, (torch.randn(2, 3, 3),))
-
-        def lu_infos(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
-            return torch.lu(x, get_infos=True)
-
-        self.checkScript(lu_infos, (torch.randn(2, 3, 3),))
-
         def lu_unpack(x):
-            A_LU, pivots = torch.lu(x)
+            A_LU, pivots = torch.linalg.lu_factor(x)
             return torch.lu_unpack(A_LU, pivots)

         for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
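The surviving JIT test pairs `torch.linalg.lu_factor` with `torch.lu_unpack`; a standalone sketch of that pattern and the identity it relies on:

import torch

x = torch.randn(2, 3, 3)
LU, pivots = torch.linalg.lu_factor(x)
P, L, U = torch.lu_unpack(LU, pivots)
# The unpacked factors reconstruct the input up to rounding: x = P @ L @ U
torch.testing.assert_close(P @ L @ U, x)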

test/test_linalg.py

Lines changed: 5 additions & 14 deletions

@@ -4368,7 +4368,7 @@ def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_c
             size_b = size_b[1:]

         if well_conditioned:
-            PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
+            PLU = torch.linalg.lu(make_randn(*size_a))
             if uni:
                 # A = L from PLU
                 A = PLU[1].transpose(-2, -1).contiguous()
@@ -5055,15 +5055,6 @@ def call_torch_fn(*args, **kwargs):
         self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
         self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))

-        if torch._C.has_lapack:
-            # lu
-            A_LU, pivots = fn(torch.lu, (0, 5, 5))
-            self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (0, 0, 0))
-            self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (2, 0, 0))
-            self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
-
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_and_complex_types_and(
@@ -5431,7 +5422,7 @@ def gen_matrices():
     @dtypes(torch.double)
     def test_lu_unpack_check_input(self, device, dtype):
         x = torch.rand(5, 5, 5, device=device, dtype=dtype)
-        lu_data, lu_pivots = torch.lu(x, pivot=True)
+        lu_data, lu_pivots = torch.linalg.lu_factor(x)

         with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
             torch.lu_unpack(lu_data, lu_pivots.long())
@@ -7320,7 +7311,7 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):

         b = torch.randn(*b_dims, dtype=dtype, device=device)
         A = make_A(*A_dims)
-        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
+        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
         self.assertEqual(info, torch.zeros_like(info))
         return b, A, LU_data, LU_pivots

@@ -7364,7 +7355,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
         # Tests tensors with 0 elements
         b = torch.randn(3, 0, 3, dtype=dtype, device=device)
         A = torch.randn(3, 0, 0, dtype=dtype, device=device)
-        LU_data, LU_pivots = torch.lu(A)
+        LU_data, LU_pivots = torch.linalg.lu_factor(A)
         self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

         sub_test(True)
@@ -7399,7 +7390,7 @@ def run_test(A_dims, b_dims, pivot=True):
         A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
         b = make_tensor(b_dims, dtype=dtype, device=device)
         x_exp = np.linalg.solve(A.cpu(), b.cpu())
-        LU_data, LU_pivots = torch.lu(A, pivot=pivot)
+        LU_data, LU_pivots = torch.linalg.lu_factor(A)
         x = torch.lu_solve(b, LU_data, LU_pivots)
         self.assertEqual(x, x_exp)
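The first hunk above swaps `torch.lu_unpack(*torch.lu(...))` for `torch.linalg.lu`, which returns the unpacked `P, L, U` factors directly. A sketch of why the two spellings are interchangeable (both products reconstruct the input):

import torch

A = torch.randn(3, 3)
P1, L1, U1 = torch.lu_unpack(*torch.linalg.lu_factor(A))
P2, L2, U2 = torch.linalg.lu(A)
torch.testing.assert_close(P1 @ L1 @ U1, P2 @ L2 @ U2)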

torch/_torch_docs.py

Lines changed: 6 additions & 6 deletions

@@ -5809,16 +5809,16 @@ def merge_dicts(*dicts):
 lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor

 Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
-LU factorization of A from :meth:`torch.lu`.
+LU factorization of A from :func:`~linalg.lu_factor`.

 This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.

 Arguments:
     b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
                 is zero or more batch dimensions.
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu` of size :math:`(*, m, m)`,
+    LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`,
                       where :math:`*` is zero or more batch dimensions.
-    LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`torch.lu` of size :math:`(*, m)`,
+    LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`,
                            where :math:`*` is zero or more batch dimensions.
                            The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of
                            :attr:`LU_data`.
@@ -5830,9 +5830,9 @@ def merge_dicts(*dicts):

     >>> A = torch.randn(2, 3, 3)
     >>> b = torch.randn(2, 3, 1)
-    >>> A_LU = torch.lu(A)
-    >>> x = torch.lu_solve(b, *A_LU)
-    >>> torch.norm(torch.bmm(A, x) - b)
+    >>> LU, pivots = torch.linalg.lu_factor(A)
+    >>> x = torch.lu_solve(b, LU, pivots)
+    >>> torch.dist(A @ x, b)
     tensor(1.00000e-07 *
            2.8312)
 """.format(**common_args))

torch/backends/cuda/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -134,10 +134,14 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend]
     * :func:`torch.linalg.cholesky_ex`
     * :func:`torch.cholesky_solve`
     * :func:`torch.cholesky_inverse`
-    * :func:`torch.lu`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
     * :func:`torch.linalg.qr`
     * :func:`torch.linalg.eigh`
+    * :func:`torch.linalg.eigvalsh`
     * :func:`torch.linalg.svd`
+    * :func:`torch.linalg.svdvals`
     '''

     if backend is None:
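For context, a hedged usage sketch of the function whose docstring this hunk edits; the string values follow the signature shown above, and a CUDA build is assumed:

import torch

# Hint CUDA linear algebra to prefer MAGMA ("cusolver" and "default" are
# the other accepted strings); the operations listed above honor the hint.
torch.backends.cuda.preferred_linalg_library("magma")

A = torch.randn(64, 64, device="cuda")
LU, pivots = torch.linalg.lu_factor(A)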

torch/functional.py

Lines changed: 17 additions & 0 deletions

@@ -1540,6 +1540,23 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
     ``True``.

+    .. warning::
+
+        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
+        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
+        future PyTorch release.
+        ``LU, pivots = torch.lu(A, compute_pivots)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
+
+        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
+
     .. note::
         * The returned permutation matrix for every matrix in the batch is
           represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.

torch/linalg/__init__.py

Lines changed: 6 additions & 1 deletion

@@ -336,6 +336,11 @@
 Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
 the output has the same batch dimensions.

+""" + fr"""
+.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
+          {common_notes["sync_note"]}
+""" + r"""
+
 .. seealso::

         :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the
@@ -372,7 +377,7 @@
 the output has the same batch dimensions.

 """ + fr"""
-.. note:: This function is computed using :func:`torch.lu`.
+.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
           {common_notes["sync_note"]}
 """ + r"""
torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion

@@ -6381,7 +6381,7 @@ def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwarg
     make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)

     def out_fn(output):
-        if op_info.name in ("linalg.lu"):
+        if op_info.name == "linalg.lu":
             return output[1], output[2]
         else:
             return output
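This one-line change is a bug fix as well as a tidy-up: `("linalg.lu")` is a parenthesized string, not a one-element tuple, so `in` was performing substring matching. A minimal illustration:

print("lu" in ("linalg.lu"))    # True  -- accidental substring match
print("lu" in ("linalg.lu",))   # False -- actual tuple membership
print("lu" == "linalg.lu")      # False -- the equality the fix uses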
