Skip to content

Commit f84d4d9

Browse files
lezcanopytorchmergebot
authored andcommitted
Deprecate torch.lu_solve
**BC-breaking note**: This PR deprecates `torch.lu_solve` in favor of `torch.linalg.lu_solve_factor`. A upgrade guide is added to the documentation for `torch.lu_solve`. Note this PR DOES NOT remove `torch.lu_solve`. Pull Request resolved: pytorch#73806 Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry
1 parent a5bbfd9 commit f84d4d9

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,7 @@ std::tuple<Tensor,Tensor> solve(const Tensor& self, const Tensor& A) {
14701470
"torch.solve is deprecated in favor of torch.linalg.solve",
14711471
"and will be removed in a future PyTorch release.\n",
14721472
"torch.linalg.solve has its arguments reversed and does not return the LU factorization.\n",
1473-
"To get the LU factorization see torch.linalg.lu_factor, which can be used with torch.lu_solve or torch.lu_unpack.\n",
1473+
"To get the LU factorization see torch.linalg.lu_factor.\n",
14741474
"X = torch.solve(B, A).solution\n",
14751475
"should be replaced with\n",
14761476
"X = torch.linalg.solve(A, B)"
@@ -1489,7 +1489,7 @@ std::tuple<Tensor&,Tensor&> solve_out(const Tensor& self, const Tensor& A, Tenso
14891489
"torch.solve is deprecated in favor of torch.linalg.solve",
14901490
"and will be removed in a future PyTorch release.\n",
14911491
"torch.linalg.solve has its arguments reversed and does not return the LU factorization.\n",
1492-
"To get the LU factorization see torch.linalg.lu_factor, which can be used with torch.lu_solve or torch.lu_unpack.\n",
1492+
"To get the LU factorization see torch.linalg.lu_factor.\n",
14931493
"X = torch.solve(B, A).solution\n",
14941494
"should be replaced with\n",
14951495
"X = torch.linalg.solve(A, B)"
@@ -2439,10 +2439,26 @@ TORCH_IMPL_FUNC(linalg_lu_solve_out)(const Tensor& LU,
24392439
}
24402440

24412441
Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) {
2442+
TORCH_WARN_ONCE(
2443+
"torch.lu_solve is deprecated in favor of torch.linalg.lu_solve",
2444+
"and will be removed in a future PyTorch release.\n",
2445+
"Note that torch.linalg.lu_solve has its arguments reversed.\n",
2446+
"X = torch.lu_solve(B, LU, pivots)\n",
2447+
"should be replaced with\n",
2448+
"X = torch.linalg.lu_solve(LU, pivots, B)"
2449+
);
24422450
return at::linalg_lu_solve(LU_data, LU_pivots, self);
24432451
}
24442452

24452453
Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) {
2454+
TORCH_WARN_ONCE(
2455+
"torch.lu_solve is deprecated in favor of torch.linalg.lu_solve",
2456+
"and will be removed in a future PyTorch release.\n",
2457+
"Note that torch.linalg.lu_solve has its arguments reversed.\n",
2458+
"X = torch.lu_solve(B, LU, pivots)\n",
2459+
"should be replaced with\n",
2460+
"X = torch.linalg.lu_solve(LU, pivots, B)"
2461+
);
24462462
return at::linalg_lu_solve_out(result, LU_data, LU_pivots, self);
24472463
}
24482464

torch/_torch_docs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4339,7 +4339,7 @@ def merge_dicts(*dicts):
43394339
and will be removed in a future PyTorch release.
43404340
:func:`torch.linalg.solve` has its arguments reversed and does not return the
43414341
LU factorization of the input. To get the LU factorization see :func:`torch.linalg.lu_factor`,
4342-
which may be used with :func:`torch.lu_solve`.
4342+
which may be used with :func:`torch.linalg.lu_solve`.
43434343
43444344
``X = torch.solve(B, A).solution`` should be replaced with
43454345
@@ -5813,6 +5813,16 @@ def merge_dicts(*dicts):
58135813
58145814
This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
58155815
5816+
.. warning::
5817+
5818+
:func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`.
5819+
:func:`torch.lu_solve` will be removed in a future PyTorch release.
5820+
``X = torch.lu_solve(B, LU, pivots)`` should be replaced with
5821+
5822+
.. code:: python
5823+
5824+
X = torch.linalg.lu_solve(LU, pivots, B)
5825+
58165826
Arguments:
58175827
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
58185828
is zero or more batch dimensions.

0 commit comments

Comments
 (0)