Add linalg.lu_factor (pytorch#66933)
Summary:
Pull Request resolved: pytorch#66933

This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.

This PR also adds support for empty inputs: matrices where one of the
dimensions is zero, and batches of size zero. Note that in this case the
function simply returns empty tensors of the correct size.

We add a test and an OpInfo for the new function.

This PR also adds documentation for the new function, in line with the
documentation of the rest of `torch.linalg`.
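
A quick sketch of the intended usage (the defaults `pivot=True` and
`check_errors=False` are assumptions based on the description above):

    import torch

    A = torch.randn(3, 2, 2)                         # batch of square matrices
    LU, pivots = torch.linalg.lu_factor(A)           # raises on singular inputs
    LU, pivots, info = torch.linalg.lu_factor_ex(A)  # reports errors via `info` instead

    # Empty inputs are supported and return empty tensors of the correct size
    LU0, pivots0 = torch.linalg.lu_factor(torch.empty(0, 2, 2))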

Fixes pytorch#56590
Fixes pytorch#64014

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D32834069

Pulled By: mruberry

fbshipit-source-id: 51ef12535fa91d292f419acf83b800b86ee9c7eb
lezcano authored and facebook-github-bot committed Jan 6, 2022
1 parent 3f53365 commit a35b4b4
Showing 22 changed files with 597 additions and 336 deletions.
119 changes: 99 additions & 20 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -928,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
    result = result.unsqueeze_(-1);
  }

-  // lu_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
+  // lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
  result.copy_(other_broadcasted);

  auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
@@ -945,7 +945,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
  auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
  pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
  Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
-  lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
+  lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);

  // solve the linear system using the LU factorization
  lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
@@ -1571,30 +1571,109 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
  return result;
}

-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-DEFINE_DISPATCH(lu_stub);
+DEFINE_DISPATCH(lu_factor_stub);

-// TODO: remove check_errors argument
-// https://github.com/pytorch/pytorch/issues/64014
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
-  TORCH_CHECK(self.dim() >= 2,
-              "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
-              " instead");
-  auto m = self.size(-2);
-  auto n = self.size(-1);
-  auto req_size = self.sizes().vec();
+std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
+                                                              bool pivot,
+                                                              bool check_errors,
+                                                              Tensor& LU,
+                                                              Tensor& pivots,
+                                                              Tensor& info) {
+  TORCH_CHECK(A.dim() >= 2,
+              "expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
+  auto req_size = A.sizes().vec();
+  const auto m = req_size.cend()[-2];
+  const auto n = req_size.cend()[-1];

+  // TODO: reimplement resize_output to support an F-contiguous memory format.
+  // We should make this a standalone function
+  if (resize_output_check(LU, req_size)) {
+    // Transpose size
+    std::iter_swap(req_size.end() - 1, req_size.end() - 2);
+    LU.resize_(req_size, MemoryFormat::Contiguous);
+    LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory
+  }
req_size.pop_back();
req_size.back() = std::min(m, n);
-  auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
+  at::native::resize_output(pivots, req_size);
req_size.pop_back();
-  auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  at::native::resize_output(info, req_size);

+  const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous();

+  if (LU_f_contig && !LU.is_same(A)) {
+    LU.copy_(A);
+  }
+  const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false);

+  const auto pivots_contig = pivots.is_contiguous();
+  const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true);

+  const auto info_contig = info.is_contiguous();
+  const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);

+  lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);

+  if (!LU_f_contig) {
+    LU.copy_(*LU_);
+  }
+  if (!pivots_contig) {
+    pivots.copy_(*pivots_);
+  }
+  if (!info_contig) {
+    info.copy_(*info_);
+  }

+  if (check_errors) {
+    if (A.dim() > 2) {
+      batchCheckErrors(info, "torch.linalg.lu_factor_ex");
+    } else {
+      singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor_ex");
+    }
+  }

+  return std::tie(LU, pivots, info);
+}

+std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
+  auto LU = at::empty({0}, A.options());
+  auto pivots = at::empty({0}, A.options().dtype(kInt));
+  auto info = at::empty({0}, A.options().dtype(kInt));
+  at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
+  return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
+}

+std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
+  auto info = at::empty({0}, A.options().dtype(kInt));
+  // We pass check_errors=false and check the errors manually below so that the
+  // error message mentions lu_factor rather than lu_factor_ex
+  at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
+  if (A.dim() > 2) {
+    batchCheckErrors(info, "torch.linalg.lu_factor");
+  } else {
+    singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+  }

+  return std::tie(LU, pivots);
+}

+std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
+  Tensor LU, pivots, info;
+  std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);

+  if (A.dim() > 2) {
+    batchCheckErrors(info, "torch.linalg.lu_factor");
+  } else {
+    singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+  }

+  return std::make_tuple(std::move(LU), std::move(pivots));
+}

-  // lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
-  // 'lu' tensor is modified in-place and must be a copy of 'self'
-  Tensor lu = cloneBatchedColumnMajor(self);
-  lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
-  return std::make_tuple(lu, pivots_tensor, infos_tensor);
+// TODO Deprecate this function in favour of linalg_lu_factor_ex
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+  return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -219,12 +219,12 @@ using triangular_solve_fn = void (*)(
    bool /*unitriangular*/);
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);

-using lu_fn = void (*)(
+using lu_factor_fn = void (*)(
    const Tensor& /*input*/,
    const Tensor& /*pivots*/,
    const Tensor& /*infos*/,
    bool /*compute_pivots*/);
-DECLARE_DISPATCH(lu_fn, lu_stub);
+DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);

using lu_solve_fn = void (*)(
    const Tensor& /*b*/,
22 changes: 11 additions & 11 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -847,14 +847,14 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
  For further details, please see the LAPACK documentation for GETRF.
*/
template <typename scalar_t>
-void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(
      false,
      "Calling torch.lu on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
-  TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
+  TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");

  auto input_data = input.data_ptr<scalar_t>();
  auto pivots_data = pivots.data_ptr<int>();
@@ -876,9 +876,9 @@ void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
}

// This is a type dispatching helper function for 'apply_lu'
-void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
-    apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
+    apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
  });
}

@@ -890,8 +890,8 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
  Args:
  * `b` - [in] the right hand side matrix B
          [out] the solution matrix X
-  * `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
-  * `pivots` - [in] the pivot indices (see at::_lu_with_info)
+  * `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
+  * `pivots` - [in] the pivot indices (see at::linalg_lu_factor)

  For further details, please see the LAPACK documentation for GETRS.
*/
@@ -1005,11 +1005,11 @@ REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);

-REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
-REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_ZVECTOR_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
+REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);

REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -119,7 +119,7 @@ DEFINE_DISPATCH(linalg_vector_norm_stub);
// where info helps us identify singular matrices.
static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Tensor>> _lu_det_P_diag_U(const Tensor& self) {
  Tensor pivs, lu, infos;
-  std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
+  std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
  TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "Invalid argument passed to lu");
  auto n = self.size(-1);
  auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs)
@@ -135,7 +135,7 @@ static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Tensor>> _lu_det_P_diag_U(const Tensor& self) {
// det(A) = ([is P odd] * -2 + 1) * prod(diag(U))
std::tuple<Tensor, Tensor, Tensor> _det_lu_based_helper(const Tensor& self) {
  Tensor lu, pivs, infos;
-  std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors*/false);
+  std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
  TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "at::_det_lu_based_helper(): Invalid argument passed to LU");

  // find det(P)
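
Aside: the identity these helpers implement, det(A) = det(P) * prod(diag(U)),
can be sketched in Python (an illustration only, not code from this PR):

    import torch

    A = torch.randn(4, 4, dtype=torch.float64)
    LU, pivots = torch.linalg.lu_factor(A)

    # det(P) = (-1)^num_exchanges: each pivots[i] != i+1 marks a row exchange
    n = A.size(-1)
    num_exchanges = (torch.arange(1, n + 1) != pivots).sum().item()
    det_P = -1.0 if num_exchanges % 2 else 1.0

    # L has a unit diagonal, so diag(LU) is diag(U)
    det_A = det_P * LU.diagonal(dim1=-2, dim2=-1).prod()
    assert torch.allclose(det_A, torch.linalg.det(A))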
14 changes: 14 additions & 0 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -59,6 +59,15 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  return result;
}

+/*
+ * Borrows `borrow` when `cond` is true; otherwise clones `clone`.
+ * `contig` chooses between a C-contig (true) and an F-contig (false) clone.
+ */
+static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
+  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
+              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
+                                                      : cloneBatchedColumnMajor(clone));
+}

/*
 * This method is designed to be a faster alternative to
 * `cloneBatchedColumnMajor` with some additional features,
@@ -280,6 +289,11 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
  } else if (strstr(name, "lstsq")) {
    TORCH_CHECK_LINALG(false, name, batch_string,
        ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
+  } else if (strstr(name, "lu_factor")) {
+    TORCH_CHECK(false, name, batch_string,
+        ": U[", info, ",", info, "] is zero and using it in lu_solve would result in a division by zero. "
+        "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
+        "linalg.lu_factor_ex(A, pivot)");
  } else {
    TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
  }
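
The new `singleCheckErrors` branch above is what separates the two entry
points: `linalg.lu_factor` checks `info` and throws, while
`linalg.lu_factor_ex` leaves the check to the caller. A minimal sketch of the
difference (an illustration only, not code from this PR):

    import torch

    A = torch.zeros(2, 2)  # singular on purpose

    # lu_factor_ex does not throw by default; LAPACK reports the first zero
    # pivot of U through `info` (1-based, 0 means success)
    LU, pivots, info = torch.linalg.lu_factor_ex(A)
    print(info)  # nonzero here, since U[1,1] is exactly zero

    # lu_factor checks `info` itself and raises with the message added above
    try:
        torch.linalg.lu_factor(A)
    except RuntimeError as e:
        print(e)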
(The diffs for the remaining 17 changed files are not shown here.)
