diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index c4f21ce267165..912cddd4259c1 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -928,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
     result = result.unsqueeze_(-1);
   }

-  // lu_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
+  // lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
   result.copy_(other_broadcasted);

   auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
@@ -945,7 +945,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
   pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
   Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
-  lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
+  lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);

   // solve the linear system using the LU factorization
   lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
@@ -1571,30 +1571,109 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
   return result;
 }

-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-DEFINE_DISPATCH(lu_stub);
+DEFINE_DISPATCH(lu_factor_stub);

-// TODO: remove check_errors argument
-// https://github.com/pytorch/pytorch/issues/64014
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
-  TORCH_CHECK(self.dim() >= 2,
-           "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
-           " instead");
-  auto m = self.size(-2);
-  auto n = self.size(-1);
-  auto req_size = self.sizes().vec();
+std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
+                                                              bool pivot,
+                                                              bool check_errors,
+                                                              Tensor& LU,
+                                                              Tensor& pivots,
+                                                              Tensor& info) {
+  TORCH_CHECK(A.dim() >= 2,
+              "expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
+  auto req_size = A.sizes().vec();
+  const auto m = req_size.cend()[-2];
+  const auto n = req_size.cend()[-1];
+
+  // TODO: this is a reimplementation of resize_output with the F-contiguous memory format;
+  // we should make it a standalone function
+  if (resize_output_check(LU, req_size)) {
+    // Transpose size
+    std::iter_swap(req_size.end() - 1, req_size.end() - 2);
+    LU.resize_(req_size, MemoryFormat::Contiguous);
+    LU.transpose_(-2, -1);  // make 'LU' have Fortran contiguous memory
+  }
   req_size.pop_back();
   req_size.back() = std::min(m, n);
-  auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
+  at::native::resize_output(pivots, req_size);
   req_size.pop_back();
-  auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  at::native::resize_output(info, req_size);
+
+  const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous();
+
+  if (LU_f_contig && !LU.is_same(A)) {
+    LU.copy_(A);
+  }
+  const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false);
+
+  const auto pivots_contig = pivots.is_contiguous();
+  const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true);
+
+  const auto info_contig = info.is_contiguous();
+  const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);
+
+  lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);
+
+  if (!LU_f_contig) {
+    LU.copy_(*LU_);
+  }
+  if (!pivots_contig) {
+    pivots.copy_(*pivots_);
+  }
+  if (!info_contig) {
+    info.copy_(*info_);
+  }
+
+  if (check_errors) {
+    if (A.dim() > 2) {
+      batchCheckErrors(info, "torch.linalg.lu_factor_ex");
+    } else {
+      singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor_ex");
+    }
+  }
+
+  return std::tie(LU, pivots, info);
+}
+
+std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
+  auto LU = at::empty({0}, A.options());
+  auto pivots = at::empty({0}, A.options().dtype(kInt));
+  auto info = at::empty({0}, A.options().dtype(kInt));
+  at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
+  return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
+}
+
+std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
+  auto info = at::empty({0}, A.options().dtype(kInt));
+  // We pass check_errors=false and check manually below so that the error messages
+  // mention lu_factor rather than lu_factor_ex
+  at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
+  if (A.dim() > 2) {
+    batchCheckErrors(info, "torch.linalg.lu_factor");
+  } else {
+    singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+  }
+
+  return std::tie(LU, pivots);
+}
+
+std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
+  Tensor LU, pivots, info;
+  std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
+
+  if (A.dim() > 2) {
+    batchCheckErrors(info, "torch.linalg.lu_factor");
+  } else {
+    singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+  }
+
+  return std::make_tuple(std::move(LU), std::move(pivots));
+}

-  // lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
-  // 'lu' tensor is modified in-place and must be a copy of 'self'
-  Tensor lu = cloneBatchedColumnMajor(self);
-  lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
-  return std::make_tuple(lu, pivots_tensor, infos_tensor);
+// TODO Deprecate this function in favour of linalg_lu_factor_ex
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+  return at::linalg_lu_factor_ex(self, compute_pivots, false);
 }

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
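For reference while reading the kernel changes below, this is the Python-level contract the new entry points implement. A minimal sketch, not part of the patch:

```python
import torch

A = torch.randn(2, 3, 3)  # a batch of two square matrices

# lu_factor returns the compact (LU, pivots) pair; lu_factor_ex also returns `info`
LU, pivots = torch.linalg.lu_factor(A)
LU_ex, pivots_ex, info = torch.linalg.lu_factor_ex(A, check_errors=False)
assert (info == 0).all()  # info == 0 means the factorization succeeded

# the compact representation round-trips through lu_unpack: P @ L @ U == A
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A)
```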
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h
index 816c437d93f12..bf5ffc13c475d 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.h
+++ b/aten/src/ATen/native/BatchLinearAlgebra.h
@@ -219,12 +219,12 @@ using triangular_solve_fn = void (*)(
     bool /*unitriangular*/);
 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);

-using lu_fn = void (*)(
+using lu_factor_fn = void (*)(
     const Tensor& /*input*/,
     const Tensor& /*pivots*/,
     const Tensor& /*infos*/,
     bool /*compute_pivots*/);
-DECLARE_DISPATCH(lu_fn, lu_stub);
+DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);

 using lu_solve_fn = void (*)(
     const Tensor& /*b*/,
diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
index 593001e139727..b83ea1de9f443 100644
--- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -847,14 +847,14 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
   For further details, please see the LAPACK documentation for GETRF.
 */
 template <typename scalar_t>
-void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
 #if !AT_BUILD_WITH_LAPACK()
   TORCH_CHECK(
       false,
       "Calling torch.lu on a CPU tensor requires compiling ",
       "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
 #else
-  TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
+  TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");

   auto input_data = input.data_ptr<scalar_t>();
   auto pivots_data = pivots.data_ptr<int>();
@@ -876,9 +876,9 @@ void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bo
 }

 // This is a type dispatching helper function for 'apply_lu'
-void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
-    apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
+    apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
   });
 }

@@ -890,8 +890,8 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, b
   Args:
   * `b` -  [in] the right hand side matrix B
            [out] the solution matrix X
-  * `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
-  * `pivots` - [in] the pivot indices (see at::_lu_with_info)
+  * `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
+  * `pivots` - [in] the pivot indices (see at::linalg_lu_factor)

   For further details, please see the LAPACK documentation for GETRS.
 */
@@ -1005,11 +1005,11 @@ REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);

-REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
-REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_ZVECTOR_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
+REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);

 REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
 REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 1a127050ed5fe..ea306585f3025 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -119,7 +119,7 @@ DEFINE_DISPATCH(linalg_vector_norm_stub);
 // where info helps us identify singular matrices.
 static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Tensor>> _lu_det_P_diag_U(const Tensor& self) {
   Tensor pivs, lu, infos;
-  std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
+  std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
   TORCH_CHECK(infos.ge(0).all().item<bool>(), "Invalid argument passed to lu");
   auto n = self.size(-1);
   auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs)
@@ -135,7 +135,7 @@ static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<T
 std::tuple<Tensor, Tensor, Tensor> _det_lu_based_helper(const Tensor& self) {
   Tensor lu, pivs, infos;
-  std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors*/false);
+  std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
   TORCH_CHECK(infos.ge(0).all().item<bool>(), "at::_det_lu_based_helper(): Invalid argument passed to LU");

   // find det(P)
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 0345cc99eb803..b5d2b705dea34 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -59,6 +59,15 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
   return result;
 }

+/*
+ * contig chooses between C-contig (true) and F-contig (false)
+ */
+static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
+  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
+              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
+                                                      : cloneBatchedColumnMajor(clone));
+}
+
 /*
  * This method is designed to be a faster alternative to
  * `cloneBatchedColumnMajor` with some additional features,
@@ -280,6 +289,11 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
   } else if (strstr(name, "lstsq")) {
     TORCH_CHECK_LINALG(false, name, batch_string,
         ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
+  } else if (strstr(name, "lu_factor")) {
+    TORCH_CHECK(false, name, batch_string,
+        ": U[", info, ",", info, "] is zero and using it in lu_solve would result in a division by zero. "
+        "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
+        "linalg.lu_factor_ex(A, pivot)");
   } else {
     TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
   }
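The `borrow_else_clone` helper added above is what lets the `out=` overloads avoid copies: when the caller's output buffers already have the layout LAPACK expects (Fortran-contiguous `LU`, C-contiguous `pivots` and `info`), they are borrowed and written in place; otherwise a correctly-laid-out clone is factored and copied back. A caller-side sketch under the semantics of the code above (the buffer setup is hypothetical):

```python
import torch

A = torch.randn(5, 5)

# the transpose of a C-contiguous tensor is Fortran-contiguous (column-major)
LU = torch.empty(5, 5).t()
pivots = torch.empty(5, dtype=torch.int32)

# LU.transpose(-2, -1).is_contiguous() is True here, so the kernel can borrow
# the buffer and factor in place instead of cloning and copying back
torch.linalg.lu_factor(A, out=(LU, pivots))
```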
+ AT_ERROR("linalg.lu_factor: PyTorch was not compiled with MAGMA support."); #else // magmaLu and magmaLuNoPiv require infos and pivots tensor to be on CPU // the data is later copied back to the appropriate output tensor @@ -1835,20 +1835,15 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con int* infos_working_ptr = &infos_data[i]; magmaLu(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr); } - pivots.copy_(pivots_cpu, /*non_blocking=*/true); + pivots.copy_(pivots_cpu); } else { for (decltype(batch_size) i = 0; i < batch_size; i++) { scalar_t* input_working_ptr = &input_data[i * input_matrix_stride]; int* infos_working_ptr = &infos_data[i]; magmaLuNoPiv(m, n, input_working_ptr, leading_dimension, infos_working_ptr); } - - // fill the pivots tensor with indices using 1-based (Fortran) indexing - auto k = std::min(m, n); - Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots); - pivots.copy_(pivots_tmp); } - infos.copy_(infos_cpu, /*non_blocking=*/true); + infos.copy_(infos_cpu); #endif } @@ -1867,7 +1862,7 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con For further details, please see the MAGMA documentation for magma_dgetrf_batched. */ template -static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { +static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { #if !AT_MAGMA_ENABLED() TORCH_CHECK( false, @@ -1879,13 +1874,6 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co auto input_matrix_stride = matrixStride(input); magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount"); - // magmaLuBatched doesn't work with zero batch dimensions - // it gives CUDA error: invalid configuration argument - if (batch_size == 0) { - infos.fill_(0); - return; - } - magma_int_t m = magma_int_cast(input.size(-2), "m"); magma_int_t n = magma_int_cast(input.size(-1), "n"); auto leading_dimension = std::max(1, m); @@ -1915,11 +1903,6 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co magmaLuBatched(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue); } else { magmaLuNoPivBatched(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue); - - // fill the pivots tensor with indices using 1-based (Fortran) indexing - auto k = std::min(m, n); - Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots); - pivots.copy_(pivots_tmp); } // block CPU until all operations on the queue are finished @@ -1928,57 +1911,98 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co #endif } -static void lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma_looped", [&]{ - apply_lu_looped_magma(input, pivots, infos, compute_pivots); +static void lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_factor_magma_looped", [&]{ + apply_lu_factor_looped_magma(input, pivots, infos, compute_pivots); }); } -static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { - 
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma_batched", [&]{
-    apply_lu_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
+static void lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_factor_magma_batched", [&]{
+    apply_lu_factor_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
   });
 }

-static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
-  int64_t batch_size = batchCount(input);
+static void lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+  auto batch_size = batchCount(input);
+  // MAGMA does not work with batch_size == 0.
+  // CuSolver does not work when the matrices have no elements
+  if (input.numel() == 0) {
+    // zero out the infos as it will have one element if the input is a matrix of size (0, 0)
+    infos.zero_();
+    return;
+  }
+
+#if AT_MAGMA_ENABLED()
+  const auto lu_factor_magma = [batch_size](const Tensor& input, const Tensor& pivots, const Tensor& infos, const bool compute_pivots) {
+    if (batch_size == 1) {
+      lu_factor_looped_magma(input, pivots, infos, compute_pivots);
+    } else {
+      // There is a bug in lu_factor_batched_magma in MAGMA < 2.5.2, see
+      // https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
+      std::tuple<magma_int_t, magma_int_t, magma_int_t> version;
+      magma_version(&std::get<0>(version), &std::get<1>(version), &std::get<2>(version));
+      if (version >= std::make_tuple(2, 5, 2)) {
+        lu_factor_batched_magma(input, pivots, infos, compute_pivots);
+      } else {
+        lu_factor_looped_magma(input, pivots, infos, compute_pivots);
+      }
+    }
+  };
+#endif
+
 #ifdef USE_CUSOLVER
   auto preferred_backend = at::globalContext().linalgPreferredBackend();
   switch (preferred_backend) {
     case at::LinalgBackend::Cusolver:
-      lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      lu_factor_looped_cusolver(input, pivots, infos, compute_pivots, use_magma_);
       break;
     case at::LinalgBackend::Magma:
-      if (batch_size == 1) {
-        lu_looped_magma(input, pivots, infos, compute_pivots);
-      } else {
-        lu_batched_magma(input, pivots, infos, compute_pivots);
-      }
+#if AT_MAGMA_ENABLED()
+      lu_factor_magma(input, pivots, infos, compute_pivots);
       break;
+#endif
    default:
-      // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
-      auto m = input.size(-2);
-      // exclude complex128 since nan_to_num_ does not work with it.
-      if ((batch_size == 1 ||
-          (batch_size <= 8 && m <= 16) ||
-           !use_magma_)
-          && !input.is_complex()) {
-        lu_looped_cusolver(input, pivots, infos, compute_pivots);
+#if AT_MAGMA_ENABLED()
+      // We do not use cuSOLVER for complex inputs if !get_pivots since nan_to_num_ does not work with it.
+      // See https://github.com/pytorch/pytorch/issues/59247 for more info
+      // Provided the above, use a heuristic to determine that cusolver is faster than MAGMA
+      const auto m = input.size(-2);
+      const auto use_cusolver = ((batch_size == 1 || (batch_size <= 8 && m <= 16))
+                                 && (!input.is_complex() || compute_pivots));
+      if (use_cusolver) {
+        lu_factor_looped_cusolver(input, pivots, infos, compute_pivots, use_magma_);
       } else {
-        lu_batched_magma(input, pivots, infos, compute_pivots);
+        lu_factor_magma(input, pivots, infos, compute_pivots);
       }
+#else // USE_CUSOLVER && !AT_MAGMA_ENABLED
+      lu_factor_looped_cusolver(input, pivots, infos, compute_pivots, use_magma_);
+#endif
   }
+#else // !USE_CUSOLVER
+#if AT_MAGMA_ENABLED()
+  if (batch_size == 1) {
+    lu_factor_looped_magma(input, pivots, infos, compute_pivots);
+  } else {
+    lu_factor_magma(input, pivots, infos, compute_pivots);
+  }
 #else
-  if (batch_size == 1) {
-    lu_looped_magma(input, pivots, infos, compute_pivots);
-  }
-  else {
-    lu_batched_magma(input, pivots, infos, compute_pivots);
-  }
+  TORCH_CHECK(
+      false,
+      "Calling linalg.lu_factor on a CUDA tensor requires compiling ",
+      "PyTorch with MAGMA or cuSOLVER. Please rebuild with MAGMA or cuSOLVER.");
+#endif // AT_MAGMA_ENABLED
 #endif // USE_CUSOLVER
+
+  // If pivoting was not computed, return the trivial permutation starting at 1 (Fortran indexing)
+  if (!compute_pivots) {
+    auto k = std::min(input.size(-2), input.size(-1));
+    auto pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt));
+    pivots.copy_(pivots_tmp);
+  }
 }

-REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
+REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor);

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -2795,8 +2819,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
   Args:
   * `b` -  [in] the right hand side matrix B
            [out] the solution matrix X
-  * `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
-  * `pivots` - [in] the pivot indices (see at::_lu_with_info)
+  * `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
+  * `pivots` - [in] the pivot indices (see at::linalg_lu_factor)

   For further details, please see the MAGMA documentation for magma_dgetrs_gpu.
 */
@@ -2849,8 +2873,8 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const
   Args:
   * `b` -  [in] the right hand side matrix B
            [out] the solution matrix X
-  * `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
-  * `pivots` - [in] the pivot indices (see at::_lu_with_info)
+  * `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
+  * `pivots` - [in] the pivot indices (see at::linalg_lu_factor)

   For further details, please see the MAGMA documentation for magma_dgetrs_batched.
 */
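With the trivial-pivot filling now done once in the dispatcher rather than in each MAGMA/cuSOLVER branch, the observable `pivot=False` behavior is the following. A sketch, assuming a CUDA build (the non-pivoted path is not implemented on CPU):

```python
import torch

A = torch.randn(3, 3, device="cuda")
LU, pivots = torch.linalg.lu_factor(A, pivot=False)

# with pivot=False no rows are exchanged, so the pivots are the
# trivial 1-based (Fortran-style) permutation 1, 2, ..., k
assert torch.equal(pivots, torch.arange(1, 4, dtype=torch.int32, device="cuda"))
```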
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
index 1bb368866c11a..fd982a630c150 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
@@ -1423,56 +1423,40 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
 // The 'apply_' word is used for templated by dtype functions that call an API routine
 // underneath. Since the cusolver API has a slightly different structure we do not prepend
 // apply_ to this function.
-void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
-  // Fill the pivots tensor with indices using 1-based (Fortran) indexing. This
-  // is needed for maintaining the same results with MAGMA.
-  auto k = std::min(self.size(-2), self.size(-1));
-  Tensor pivots_tmp = at::arange(1, k + 1, self.options().dtype(at::kInt)).expand_as(pivots);
-  pivots.copy_(pivots_tmp);
-
-  AT_DISPATCH_FLOATING_TYPES(
+void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots, const bool use_magma_) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
     self.scalar_type(),
-    "lu_cusolver",
+    "lu_factor_cusolver",
     [&self, &pivots, &infos, &get_pivots]() {
-      int m = cuda_int_cast(self.size(-2), "m");
-      int n = cuda_int_cast(self.size(-1), "n");
-      int lda = std::max<int>(1, m);
-      int64_t self_stride = matrixStride(self);
-      int64_t batch_size = batchCount(self);
-      scalar_t* self_data = self.data_ptr<scalar_t>();
-      int* infos_data = infos.data_ptr<int>();
-
-      auto handle = at::cuda::getCurrentCUDASolverDnHandle();
+      const auto m = cuda_int_cast(self.size(-2), "m");
+      const auto n = cuda_int_cast(self.size(-1), "n");
+      const auto lda = std::max<int>(1, m);
+      const auto self_stride = matrixStride(self);
+      const auto batch_size = batchCount(self);
+      const auto self_data = self.data_ptr<scalar_t>();
+      const auto infos_data = infos.data_ptr<int>();
+
+      const auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
+      const auto pivots_stride = get_pivots ? pivots.size(-1) : 0;
+
+      const auto handle = at::cuda::getCurrentCUDASolverDnHandle();
       for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
-        if (get_pivots) {
-          auto pivots_data = pivots.data_ptr<int>();
-          auto pivots_stride = pivots.size(-1);
-          at::cuda::solver::getrf<scalar_t>(
-            handle, m, n,
-            self_data + batch * self_stride,
-            lda,
-            pivots_data + batch * pivots_stride,
-            infos_data + batch
-          );
-        }
-        else {
-          at::cuda::solver::getrf<scalar_t>(
-            handle, m, n,
-            self_data + batch * self_stride,
-            lda,
-            nullptr,
-            infos_data + batch
-          );
-        }
+        at::cuda::solver::getrf<scalar_t>(
+          handle, m, n,
+          self_data + batch * self_stride,
+          lda,
+          get_pivots ? pivots_data + batch * pivots_stride : nullptr,
+          infos_data + batch
+        );
       }
   });

   // Necessary because cuSOLVER uses nan for outputs that correspond to 0 in MAGMA for non-pivoted LU.
   // See https://github.com/pytorch/pytorch/issues/53879 for more details.
-  if (!get_pivots) {
+  if (!get_pivots && use_magma_) {
     at::nan_to_num_(const_cast<Tensor&>(self), 0, std::numeric_limits<double>::infinity(),
       -std::numeric_limits<double>::infinity());
   }
 }
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
index 6b7b3adc10d7c..d2119c7144b31 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
@@ -61,7 +61,7 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
 void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
 void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);

-void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
+void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots, const bool use_magma_);

 #endif  // USE_CUSOLVER
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d4c26f3254bfb..d0a1acca63c44 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -7093,8 +7093,6 @@
 - func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   variants: function
-  dispatch:
-    CPU, CUDA: _lu_with_info

 - func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -10768,6 +10766,27 @@
   dispatch:
     CPU, CUDA: linalg_cross_out

+# linalg.lu_factor
+- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_lu_factor_ex
+
+- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!)
info) + python_module: linalg + variants: function + dispatch: + CPU, CUDA: linalg_lu_factor_ex_out + - func: linalg_det(Tensor self) -> Tensor python_module: linalg variants: function diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 2962f666e74bc..2e2a31a0f0a3b 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -33,6 +33,7 @@ Decompositions cholesky qr + lu_factor eig eigvals eigh @@ -101,3 +102,4 @@ Experimental Functions cholesky_ex inv_ex + lu_factor_ex diff --git a/test/test_linalg.py b/test/test_linalg.py index 041e46edc8408..fbc8f0a7f5b54 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -5788,87 +5788,72 @@ def test_householder_product_errors_and_warnings(self, device): with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): torch.linalg.householder_product(reflectors, tau) - @precisionOverride({torch.complex64: 5e-6}) - @skipCUDAIfNoMagma + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack - @dtypes(torch.double, torch.cfloat, torch.cdouble) - def test_lu(self, device, dtype): + @dtypes(*floating_and_complex_types()) + def test_linalg_lu_factor_and_lu(self, device, dtype): + # Tests lu, linalg.lu_factor and linalg.lu_factor_ex from torch.testing._internal.common_utils import random_matrix - def run_test(device, pivot): - def run_subtest(matrix_size, batches, device, pivot, singular=False, a=None): - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if a is None: - a = random_matrix(rows, columns, *batches, **dict(singular=singular, dtype=dtype, device=device)) - a_LU_info, pivots_info, info_ = a.lu(pivot=pivot, get_infos=True) - self.assertEqual(a_LU_info.size(), torch.Size(batches + (rows, columns))) - self.assertEqual(pivots_info.size(), torch.Size(batches + (min(rows, columns),))) - self.assertEqual(info_.size(), torch.Size(batches)) - # If a randomly generated input matrix is singular, - # then info_ contains indices i such that U[i, i] == - # 0. This however conveys that the factorization was - # successful albeit with a singular input. 
Therefore,
-            # we require info.min() >= 0
-            self.assertGreaterEqual(info_.min(), 0)
-            a_LU, pivots = a.lu(pivot=pivot)
-            self.assertEqual(a_LU, a_LU_info)
-            self.assertEqual(pivots_info, pivots)
-
-
-            P, L, U = torch.lu_unpack(a_LU, pivots)
-            P_ = P.cpu().numpy()
-            L_ = L.cpu().numpy()
-            U_ = U.cpu().numpy()
-
-            self.assertEqual(np.matmul(P_, np.matmul(L_, U_)), a)
+        def run_test(A, pivot, singular, fn):
+            k = min(A.shape[-2:])
+            batch = A.shape[:-2]
+            check_errors = (fn == torch.linalg.lu_factor)
+            if singular and check_errors:
+                # It may or may not throw as the LU decomposition without pivoting
+                # may still succeed for singular matrices
+                try:
+                    LU, pivots = fn(A, pivot=pivot)
+                except RuntimeError:
+                    return
+            else:
+                LU, pivots = fn(A, pivot=pivot)[:2]

-            if self.device_type == 'cuda':
-                # lu without pivoting is implemented only for cuda device
-                a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
-                P_nopiv, L_nopiv, U_nopiv = torch.lu_unpack(a_LU_info_nopiv, nopiv)
-                P_nopiv_ = P_nopiv.cpu().numpy()
-                L_nopiv_ = L_nopiv.cpu().numpy()
-                U_nopiv_ = U_nopiv.cpu().numpy()
-
-                self.assertEqual(np.matmul(P_nopiv_, np.matmul(L_nopiv_, U_nopiv_)), a)
-
-                k = min(rows, columns)
-                self.assertEqual(nopiv, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(a.shape[:-2] + (k, )))
-                if not singular:
-                    # It is not guaranteed that LU factorization
-                    # without pivoting is able to determine if a
-                    # matrix is singular while LU factorization
-                    # with pivoting is. Therefore, we require the
-                    # equality of info-s only for non-singular
-                    # matrices.
-                    # NOTE: infor_ is reshaped because info_nopiv might have
-                    # squashed batch dimensions for complex types on CUDA,
-                    # see the TODOs above.
-                    self.assertEqual(info_.reshape(info_nopiv.shape), info_nopiv)
-
-            for ms, batch in itertools.product([3, 5, 7, (4, 2), (3, 4)], [(), (2,), (3,), (3, 5)]):
-                run_subtest(ms, batch, device, pivot)
-                run_subtest(ms, batch, device, pivot, singular=True)
-
-                # Reproducer of a magma bug, see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
-                a = torch.ones(batch + (ms if isinstance(ms, tuple) else (ms, ms)), dtype=torch.double, device=device)
-                run_subtest(ms, batch, device, pivot, singular=True, a=a)
-
-            # Info should be positive for rank deficient matrices
-            a = torch.ones(5, 3, 3, device=device)
-            self.assertGreater(a.lu(pivot=pivot, get_infos=True)[2][0], 0)
-
-        run_test(device, True)
+            self.assertEqual(LU.size(), A.shape)
+            self.assertEqual(pivots.size(), batch + (k,))
+
+            if not pivot:
+                self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
+
+            P, L, U = torch.lu_unpack(LU, pivots)
+
+            self.assertEqual(P @ L @ U, A)
+
+        sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
+        batches = ((0,), (2,), (3,), (1, 0), (3, 5))
+        # The non-pivoting path is only implemented on CUDA
+        pivots = (True, False) if self.device_type == "cuda" else (True,)
+        fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
+        for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
+            m, n = ms
+            A = random_matrix(m, n, *batch, singular=singular, dtype=dtype, device=device)
+            # Empty matrices are only tested once; the singular flag is irrelevant for them
+            if A.numel() == 0 and not singular:
+                continue
+            run_test(A, pivot, singular, fn)
+
+        # Reproducer of a magma bug,
+        # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
+        # This is also a bug in cuSOLVER < 11.3
+        if (dtype == torch.double
+           and singular
+           and (torch.version.cuda is None or
+                [int(v) for v in torch.version.cuda.split('.')[:2]] >= [11, 3])):
+            A = torch.ones(batch + ms, dtype=dtype, device=device)
+            run_test(A, pivot, singular, fn)
+
+        # Info should be positive for rank deficient matrices
+        A = torch.ones(5, 3, 3, device=device)
+        self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())

         if self.device_type == 'cpu':
             # Error checking, no pivoting variant on CPU
-            with self.assertRaisesRegex(RuntimeError, 'lu without pivoting is not implemented on the CPU'):
+            with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
                 torch.lu(torch.empty(1, 2, 2), pivot=False)
-        else:
-            run_test(device, False)
+
+            with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
+                torch.linalg.lu_factor(torch.empty(1, 2, 2), pivot=False)

     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma
@@ -7316,8 +7301,9 @@ def test_slogdet_errors_and_warnings(self, device, dtype):
         with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
             torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))

-    @slowTest
-    @skipCUDAIfNoMagma
+    @skipCUDAIf(torch.version.cuda is not None
+                and [int(v) for v in torch.version.cuda.split(".")[:2]] < [11, 3], "There's a bug in cuSOLVER < 11.3")
+    @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
     @dtypes(torch.double)
     def test_det_logdet_slogdet(self, device, dtype):
@@ -7335,15 +7321,15 @@ def test_single_det(M, target, desc):
             # Test det
             self.assertEqual(det, target_sdet * target_logabsdet.exp(),
-                             atol=1e-7, rtol=0, msg='{} (det)'.format(desc))
+                             atol=1e-6, rtol=0, msg='{} (det)'.format(desc))

             # Test slogdet
             # Compare the overall value rather than individual parts because of
             # precision issues when det is near zero.
self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (slogdet)'.format(desc)) + atol=1e-6, rtol=0, msg='{} (slogdet)'.format(desc)) self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (linalg_slogdet)'.format(desc)) + atol=1e-6, rtol=0, msg='{} (linalg_slogdet)'.format(desc)) # Test logdet # Compare logdet against our own pytorch slogdet because they should @@ -7354,7 +7340,7 @@ def test_single_det(M, target, desc): self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc)) else: self.assertEqual(logdet.exp(), target_logabsdet.exp(), - atol=1e-7, rtol=0, msg='{} (logdet non-negative case)'.format(desc)) + atol=1e-6, rtol=0, msg='{} (logdet non-negative case)'.format(desc)) eye = torch.eye(5, dtype=dtype, device=device) test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity') diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index efc1e07a56812..c730991e1f492 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -19,7 +19,7 @@ '_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask', 'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq', 'linalg_eig', 'linalg_cholesky_ex', 'frexp', 'lu_unpack', 'histogram', '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', - '_fused_moving_avg_obs_fq_helper', + '_fused_moving_avg_obs_fq_helper', 'linalg_lu_factor', 'linalg_lu_factor_ex', '_det_lu_based_helper', '_lu_with_info', } @@ -85,6 +85,8 @@ def test_namedtuple_return(self): op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True), op(operators=['linalg_cholesky_ex'], input=(), names=('L', 'info'), hasout=True), op(operators=['linalg_inv_ex'], input=(), names=('inverse', 'info'), hasout=True), + op(operators=['linalg_lu_factor'], input=(), names=('LU', 'pivots'), hasout=True), + op(operators=['linalg_lu_factor_ex'], input=(), names=('LU', 'pivots', 'info'), hasout=True), op(operators=['fake_quantize_per_tensor_affine_cachemask'], input=(0.1, 0, 0, 255), names=('output', 'mask',), hasout=False), op(operators=['fake_quantize_per_channel_affine_cachemask'], diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index b59b07a92923e..1f55514ae6e0a 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -932,9 +932,10 @@ other: zeros_like(other) result: self_t.zero_() -- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info) - self: _lu_with_info_backward(grad, self, LU, pivots) - LU: _lu_with_info_jvp(self_t, LU, pivots) +- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + A: lu_factor_ex_backward(grad, A, LU, pivots) + LU: lu_factor_ex_jvp(A_t, LU, pivots) + output_differentiability: [True, False, False] - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor self, LU_data: lu_solve_backward(grad, result, LU_data, LU_pivots, grad_input_mask) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 7ce2d0246976d..340b8ec87ded0 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -107,9 +107,9 @@ 'replication_pad1d_backward', 'replication_pad2d_backward', 
     'replication_pad3d_backward', 'diag', 'masked_scatter', 'masked_select', 'index_add', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
-    'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical',
+    'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical', 'linalg_lu_factor_ex',
     'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
+    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
     'linalg_solve_triangular', 'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
 }
diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h
index d445f2a9637dc..24ebd92705b05 100644
--- a/torch/csrc/api/include/torch/linalg.h
+++ b/torch/csrc/api/include/torch/linalg.h
@@ -68,6 +68,14 @@ inline Tensor& householder_product_out(Tensor& result, const Tensor& input, cons
   return torch::linalg_householder_product_out(result, input, tau);
 }

+inline std::tuple<Tensor, Tensor> lu_factor(const Tensor& self, const bool pivot) {
+  return torch::linalg_lu_factor(self, pivot);
+}
+
+inline std::tuple<Tensor&, Tensor&> lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot) {
+  return torch::linalg_lu_factor_out(LU, pivots, self, pivot);
+}
+
 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(const Tensor& self, const Tensor& b, c10::optional<double> cond, c10::optional<c10::string_view> driver) {
   return torch::linalg_lstsq(self, b, cond, driver);
 }
@@ -341,6 +349,17 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, c10::string_view ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
   return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
 }

+/// Computes the pivoted LU factorization
+///
+/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu_factor
+inline std::tuple<Tensor, Tensor> lu_factor(const Tensor& input, const bool pivot=true) {
+  return detail::lu_factor(input, pivot);
+}
+
+inline std::tuple<Tensor&, Tensor&> lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot=true) {
+  return detail::lu_factor_out(LU, pivots, self, pivot);
+}
+
 inline Tensor norm(const Tensor& self, const optional<Scalar>& opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
   return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
 }
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index b8df8b3377654..5d2292303d1c8 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -4401,7 +4401,7 @@ Tensor lu_solve_jvp(
   // The identity permutation pivots are 1-based because of the Fortran-like LAPACK interfaces.
   // More details on the permutation matrix canceling note:
   // as part of forward AD we need to compute A^{-1} dA.
- // Since A = P L U and P is not differentiable, we get + // Since A = P L U and P is locally constant for full-rank matrices, we get // dA = P d(L U), A^{-1} = (L U)^{-1} P^T, so // A^{-1} dA = (L U)^{-1} d(L U), which is lu_solve with // the pivots set to the identity permutation @@ -4611,8 +4611,8 @@ Tensor plu_backward_base( auto U_principal_H = U_principal.mH(); auto U_grad_principal = U_grad.narrow(-2, 0, k).narrow(-1, 0, k); - auto phi_L = L_principal_H.matmul(L_grad_principal).tril_(-1); - auto phi_U = U_grad_principal.matmul(U_principal_H).triu_(); + auto phi_L = L_principal_H.matmul(L_grad_principal).tril(-1); + auto phi_U = U_grad_principal.matmul(U_principal_H).triu(); auto phi = phi_L + phi_U; @@ -4621,29 +4621,21 @@ Tensor plu_backward_base( auto U_complement = U.narrow(-2, 0, k).narrow(-1, k, n - k); auto U_grad_complement = U_grad.narrow(-2, 0, k).narrow(-1, k, n - k); - // The result for X1_grad and X2_grad from above. + auto phi_complement = U_grad_complement.matmul(U_complement.mH()).tril(-1); + + // recall the result for X1_grad and X2_grad from above. // It can be rewritten as // (X1_grad | X2_grad) = P L^{-H} psi, where // psi = (psi1 | psi2) // = ([L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H} | U2_grad), // so it is filled in parts. - // - // fill psi2 in - - // phi_complement = U2_grad U2^H o 1_L - auto phi_complement = U_grad_complement.matmul(U_complement.transpose(-2, -1).conj()).tril_(-1); - // phi = [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] - phi.sub_(phi_complement); - // solve for psi1 to avoid the inversion of U1^H - Tensor psi_principal = at::linalg_solve_triangular(U_principal_H, phi, - /*upper=*/false, - /*left=*/false, - /*unitriangular=*/false); - auto psi = at::empty_like(self); - psi.narrow(-2, 0, k).narrow(-1, k, n - k).copy_(U_grad_complement); - psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal); + auto psi_principal = at::linalg_solve_triangular(U_principal_H, phi - phi_complement, + /*upper=*/false, + /*left=*/false, + /*unitriangular=*/false); + auto psi = at::cat({psi_principal, U_grad_complement}, /*dim=*/-1); self_grad = P.matmul(at::linalg_solve_triangular(L_principal_H, psi, /*upper=*/true, @@ -4656,18 +4648,14 @@ Tensor plu_backward_base( auto L_complement = L.narrow(-2, k, m - k).narrow(-1, 0, k); auto L_grad_complement = L_grad.narrow(-2, k, m - k).narrow(-1, 0, k); - auto phi_complement = L_complement.mH().matmul(L_grad_complement).triu_(); - phi.sub_(phi_complement); + auto phi_complement = L_complement.mH().matmul(L_grad_complement).triu(); - auto psi_principal = at::linalg_solve_triangular(L_principal_H, phi, + auto psi_principal = at::linalg_solve_triangular(L_principal_H, phi - phi_complement, /*upper=*/true, /*left=*/true, /*unitriangular=*/true); - - auto psi = at::empty_like(self); - psi.narrow(-2, k, m - k).narrow(-1, 0, k).copy_(L_grad_complement); - psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal); + auto psi = at::cat({psi_principal, L_grad_complement}, -2); self_grad = at::linalg_solve_triangular(U_principal_H, P.matmul(psi), /*upper=*/false, @@ -4678,7 +4666,7 @@ Tensor plu_backward_base( return self_grad; } -Tensor _lu_with_info_backward( +Tensor lu_factor_ex_backward( const Tensor& grad, const Tensor& self, const Tensor& LU, @@ -4692,8 +4680,8 @@ Tensor _lu_with_info_backward( return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U); } -Tensor _lu_with_info_jvp( - const Tensor& dX, +Tensor lu_factor_ex_jvp( + const Tensor& dA, const Tensor& 
LU, const Tensor& pivs ) { @@ -4707,31 +4695,20 @@ Tensor _lu_with_info_jvp( auto n = LU.size(-1); auto k = std::min(m, n); - auto pdX = P.mT().matmul(dX); + auto PdA = P.mT().matmul(dA); // similar to the backward implementation, we also consider block structures such as: // for a matrix A of size m x n we decompose it as // A = (A1 | A2) with A1 of size m x m if m <= n and // A = (A1^T | A2^T)^T with A1 of size n x n if m > n. - auto pdX1 = pdX.narrow(-2, 0, k).narrow(-1, 0, k); + auto PdA1 = PdA.narrow(-2, 0, k).narrow(-1, 0, k); auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k); auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k); - // dK = L1^{-1} pdX1 - auto dK = std::get<0>(at::triangular_solve( - pdX1, - L1, - /*upper=*/false, - /*transpose=*/false, - /*unitriangular=*/true - )); + // dK = L1^{-1} PdA1 + auto dK = at::linalg_solve_triangular(L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/true); // dK <- dK U1^{-1} - dK = std::get<0>(at::triangular_solve( - dK.mT(), - U1, - /*upper=*/true, - /*transpose=*/true - )).mT(); + dK = at::linalg_solve_triangular(U1, dK, /*upper=*/true, /*left=*/false); auto dL1 = L1.matmul(dK.tril(-1)); auto dU1 = dK.triu().matmul(U1); @@ -4744,36 +4721,24 @@ Tensor _lu_with_info_jvp( return dL1 + dU1; } else { - auto dLU = at::zeros_like(LU); - dLU.narrow(-2, 0, k).narrow(-1, 0, k).copy_(dL1 + dU1); + auto dLU1 = dL1 + dU1; if (m < n) { - // we only need to update dU2 defined as - // dU2 := L1^{-1} (pdX2 - dL1 U2) - auto pdX2 = pdX.narrow(-1, k, n - k); + // we only need to update dLU2 defined as + // dLU2 := L1^{-1} PdA2 - dK.tril(-1) U2 + auto PdA2 = PdA.narrow(-1, k, n - k); auto U2 = U.narrow(-1, k, n - k); - dLU.narrow(-1, k, n - k).copy_(std::get<0>(at::triangular_solve( - pdX2 - dL1.matmul(U2), - L1, - /*upper=*/false, - /*transpose=*/false, - /*unitriangular=*/true - ))); + auto dLU2 = at::linalg_solve_triangular(L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/true) - dK.tril(-1).matmul(U2); + return at::cat({dLU1, dLU2}, /*dim=*/-1); } else { - // we only need to update dL2 defined as - // dL2 := (pdX2 - L2 dU1) U1^{-1} - auto pdX2 = pdX.narrow(-2, k, m - k); + // we only need to update dLU2 defined as + // dLU2 := PdA2 U1^{-1} - L2 dK.triu() + auto PdA2 = PdA.narrow(-2, k, m - k); auto L2 = L.narrow(-2, k, m - k); - dLU.narrow(-2, k, m - k).copy_(std::get<0>(at::triangular_solve( - (pdX2 - L2.matmul(dU1)).mT(), - U1, - /*upper=*/true, - /*transpose=*/true - )).mT()); + auto dLU2 = at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) - L2.matmul(dK.triu()); + return at::cat({dLU1, dLU2}, /*dim=*/-2); } - - return dLU; } } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index ba75abc114fc0..4fa04aafb1ddf 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -368,13 +368,13 @@ Tensor lu_backward_base( const Tensor& L, const Tensor& U ); -Tensor _lu_with_info_backward( +Tensor lu_factor_ex_backward( const Tensor& grad, const Tensor& self, const Tensor& LU, const Tensor& pivs ); -Tensor _lu_with_info_jvp( +Tensor lu_factor_ex_jvp( const Tensor& dX, const Tensor& LU, const Tensor& pivs diff --git a/torch/functional.py b/torch/functional.py index d4ae1f2ccd19b..cbe0d6f7dc7d1 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1648,9 +1648,10 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): ``True``. .. note:: - * The pivots returned by the function are 1-indexed. 
If :attr:`pivot`
-      is ``False``, then the returned pivots is a tensor filled with
-      zeros of the appropriate size.
+    * The returned permutation matrix for every matrix in the batch is
+      represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
+      ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
+      the ``i``-th row was permuted with the ``j-1``-th row.
     * LU factorization with :attr:`pivot` = ``False`` is not available
       for CPU, and attempting to do so will throw an error. However,
       LU factorization with :attr:`pivot` = ``False`` is available for
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 801363c48f3c9..6b8506dd0fa80 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -9,7 +9,11 @@
 Tensor = torch.Tensor

 common_notes = {
-    "sync_note": """When inputs are on a CUDA device, this function synchronizes that device with the CPU."""
+    "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""",
+    "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.",
+    "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.",
+    "sync_note_has_ex": ("When inputs are on a CUDA device, this function synchronizes that device with the CPU. "
+                         "For a version of this function that does not synchronize, see :func:`{}`.")
 }
@@ -163,9 +167,11 @@
 ``info`` filled with zeros indicates that the decomposition was successful.
 If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown.

-.. note:: If :attr:`A` is on a CUDA device, this function may synchronize that device with the CPU.
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}

-.. warning:: This function is "experimental" and it may change in a future PyTorch release.
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""

 .. seealso::
         :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors.
@@ -292,11 +298,11 @@
 Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
 the output has the same batch dimensions.

-.. note::
-    If :attr:`A` is on a CUDA device then this function may synchronize
-    that device with the CPU.
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}

-.. warning:: This function is "experimental" and it may change in a future PyTorch release.
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""

 .. seealso::
@@ -2021,6 +2027,112 @@
     https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
 """)

+lu_factor = _add_docstr(_linalg.linalg_lu_factor, r"""
+linalg.lu_factor(A, *, pivot=True, out=None) -> (Tensor, Tensor)
+
+Computes a compact representation of the LU factorization with partial pivoting of a matrix.
+
+This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`.
+If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve`
+to solve systems of linear equations that share the matrix :attr:`A`.
+
+The returned decomposition is represented as a named tuple `(LU, pivots)`.
+The ``LU`` matrix has the same shape as the input matrix ``A``. Its upper and lower triangular
+parts encode the non-constant elements of ``L`` and ``U`` of the LU decomposition of ``A``.
+
+The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents
+that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row.
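The pivot convention just described can be made concrete with a small hypothetical helper (not part of the patch) that replays the row swaps into a plain permutation:

```python
import torch

def pivots_to_permutation(pivots):
    # LAPACK-style pivots: step i swaps row i with row pivots[i]-1, sequentially
    perm = list(range(len(pivots)))
    for i, j in enumerate(pivots.tolist()):
        perm[i], perm[j - 1] = perm[j - 1], perm[i]
    return perm

A = torch.randn(3, 3)
LU, pivots = torch.linalg.lu_factor(A)
P, _, _ = torch.lu_unpack(LU, pivots)
# permuting the rows of P by the recovered permutation yields the identity
assert torch.equal(P[pivots_to_permutation(pivots)], torch.eye(3))
```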
+
+On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU
+decomposition without pivoting if it exists.
+
+Supports inputs of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if the inputs are batches of matrices then
+the output has the same batch dimensions.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")}
+""" + r"""
+.. warning:: The LU decomposition is almost never unique, as often there are different permutation
+             matrices that can yield different LU decompositions.
+             As such, different platforms, like SciPy, or inputs on different devices,
+             may produce different valid decompositions.
+
+.. warning:: Gradient computations are only supported if the input matrix is full-rank.
+             If this condition is not met, no error will be thrown, but the gradient may not be finite.
+             This is because the LU decomposition with pivoting is not differentiable at these points.
+
+.. seealso::
+
+        :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this
+        function provided the input matrix was square and invertible.
+
+        :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly
+        non-square matrix.
+
+        :func:`torch.linalg.solve` solves a system of linear equations. It is a composition
+        of :func:`~lu_factor` and :func:`~lu_solve`.
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU
+                            decomposition. :attr:`pivot`\ `= False` is not supported on CPU. Default: `True`.
+    out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(LU, pivots)`.
+
+Raises:
+    RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A`
+                  is not invertible.
+
+Examples::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> B1 = torch.randn(2, 3, 4)
+    >>> B2 = torch.randn(2, 3, 7)
+    >>> A_factor = torch.linalg.lu_factor(A)
+    >>> X1 = torch.linalg.lu_solve(A_factor, B1)
+    >>> X2 = torch.linalg.lu_solve(A_factor, B2)
+    >>> torch.allclose(A @ X1, B1)
+    True
+    >>> torch.allclose(A @ X2, B2)
+    True
+
+.. _invertible:
+    https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
+""")
+
+lu_factor_ex = _add_docstr(_linalg.linalg_lu_factor_ex, r"""
+linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor)
+
+This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`.
+It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}
+
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU
+                            decomposition. :attr:`pivot`\ `= False` is not supported on CPU. Default: `True`.
+    check_errors (bool, optional): controls whether to check the content of ``infos`` and raise
+                                   an error if it is non-zero. Default: `False`.
+    out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`.
Default: `None`. + +Returns: + A named tuple `(LU, pivots, info)`. + +.. _LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") tensorinv = _add_docstr(_linalg.linalg_tensorinv, r""" linalg.tensorinv(A, ind=2, *, out=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index 60f2b2ae1641f..4f81399350055 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -617,6 +617,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.masked_scatter: lambda input, mask, source: -1, torch.masked_select: lambda input, mask, out=None: -1, torch.matmul: lambda input, other, out=None: -1, + torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1, + torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1, torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul torch.matrix_power: lambda input, n: -1, torch.linalg.matrix_power: lambda input, n, out=None: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3b059e4bbc84f..7433894a65be3 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1306,7 +1306,7 @@ def sample_inputs_linalg_det(op_info, device, dtype, requires_grad): t.requires_grad = requires_grad return [SampleInput(t) for t in inputs] -def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad): +def sample_inputs_linalg_det_singular(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) def make_singular_matrix_batch_base(size, rank): @@ -1319,13 +1319,13 @@ def make_singular_matrix_batch_base(size, rank): b = make_arg(size[:-2] + (rank, n)) / 10 x = a @ b - lu, pivs = x.lu() + lu, pivs, _ = torch.linalg.lu_factor_ex(x) p, l, u = torch.lu_unpack(lu, pivs) u_diag_abs = u.diagonal(0, -2, -1).abs() u_diag_abs_largest = u_diag_abs.max(dim=-1, keepdim=True).values u_diag_abs_smallest_idxs = torch.topk(u_diag_abs, k=(n - rank), largest=False).indices u.diagonal(0, -2, -1).div_(u_diag_abs_largest) - u.diagonal(0, -2, -1)[..., u_diag_abs_smallest_idxs] = torch.finfo(dtype).eps + u[..., u_diag_abs_smallest_idxs] = torch.finfo(dtype).eps matrix = p @ l @ u @@ -4449,6 +4449,7 @@ def sample_inputs_take(op_info, device, dtype, requires_grad): # Empty cases src_sizes = [(0,), (), (1,), (3, 2)] src_gen = (make_arg(size) for size in src_sizes) + idx = make_idx((0,), high=1) for src in src_gen: yield SampleInput(input=src.detach().clone().requires_grad_(requires_grad), @@ -5304,37 +5305,66 @@ def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, ** def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_fullrank_matrices_with_distinct_singular_values, + dtype=dtype, device=device, requires_grad=requires_grad) + # not needed once OpInfo tests support Iterables batch_shapes = ((), (3,), (3, 3)) for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)): shape = batch_shape + (S + size_delta, S) - input = make_tensor(shape, device, dtype, requires_grad=requires_grad, low=None, high=None) + input = make_arg(*shape) yield SampleInput(input, args=(True, get_infos)) +def sample_inputs_linalg_lu_factor(op_info, device, dtype, requires_grad=False, **kwargs): + # When calling 
+    # When calling `lu_factor` we need to ensure that the matrix is invertible
+    make_fn = make_tensor if "ex" in op_info.name else make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # not needed once OpInfo tests support Iterables
+    def generate_samples():
+        batch_shapes = ((), (3,), (3, 3))
+        # pivot=False is only supported on CUDA
+        pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
+        deltas = (-2, -1, 0, +1, +2)
+        for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
+            shape = batch_shape + (S + delta, S)
+            # make_fullrank_matrices_with_distinct_singular_values takes a *shape splat, while make_tensor takes a tuple
+            A = make_arg(shape) if "ex" in op_info.name else make_arg(*shape)
+            yield SampleInput(A, kwargs={"pivot": pivot})
+
+    return list(generate_samples())

 def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
-    from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(make_fn, dtype=dtype, device=device)
+    make_b = partial(make_tensor, dtype=dtype, device=device)

-    batches = [(), (0, ), (2, )]
-    ns = [5, 3, 0]
-    nrhs = [0, 1, 6]
+    batches = ((), (0, ), (2, ))
+    ns = (5, 3, 0)
+    nrhs = (0, 1, 6)
     for n, batch, rhs in product(ns, batches, nrhs):
-        a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype, device=device)
-        requires_grad_options = (False,) if not requires_grad else (True, False)
+        with torch.no_grad():
+            shape_a = batch + (n, n)
+            a = make_a(*shape_a)
+            lu, pivs = a.lu()
+            lu = lu.contiguous()
+
+            shape_b = batch + (n, rhs)
+            b = make_b(shape_b)
+
+        grads = (False,) if not requires_grad else (True, False)
         # we try all possible combinations of requires_grad for each input
-        for lu_requires_grad, b_requires_grad in product(requires_grad_options, requires_grad_options):
+        for lu_grad, b_grad in product(grads, grads):
             # when requires_grad == True, at least one input has to have requires_grad enabled
-            if requires_grad and not lu_requires_grad and not b_requires_grad:
+            if requires_grad and not lu_grad and not b_grad:
                 continue
-            # we run LU several times to guarantee that the produced SampleInputs are independent
-            # this is especially important when setting different requries_grad for same tensors!
-            lu, pivs = a.lu()
-            lu.requires_grad = lu_requires_grad
-            b = torch.randn(*batch, n, rhs, dtype=dtype, device=device)
-            b.requires_grad = b_requires_grad
-            yield SampleInput(b, args=(lu, pivs))
+            lu_ = lu.detach().clone()
+            lu_.requires_grad_(lu_grad)
+            b_ = b.detach().clone()
+            b_.requires_grad_(b_grad)
+            yield SampleInput(b_, args=(lu_, pivs))

 def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
     # not needed once OpInfo tests support Iterables
@@ -7404,8 +7434,10 @@ def sample_inputs_softplus(op_info, device, dtype, requires_grad, **kwargs):
     ]

 def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = make_fullrank_matrices_with_distinct_singular_values
+
     def make_input():
-        return make_fullrank_matrices_with_distinct_singular_values(12, 12, device=device, dtype=dtype)
+        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)

     # lhs / rhs shape can have any number of dimensions as long as their product equals 12
     shapes = [
@@ -10127,13 +10159,29 @@ def ref_pairwise_distance(input1, input2):
         dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
         supports_autograd=False,
         sample_inputs_func=sample_inputs_comparison_ops),
+    OpInfo('linalg.lu_factor',
+           aten_name='linalg_lu_factor',
+           op=torch.linalg.lu_factor,
+           dtypes=floating_and_complex_types(),
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_linalg_lu_factor,
+           skips=(
+               # Call to .item() in checkErrors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_composite_compliance'),
+           ),
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack]),
+    OpInfo('linalg.lu_factor_ex',
+           aten_name='linalg_lu_factor_ex',
+           op=torch.linalg.lu_factor_ex,
+           dtypes=floating_and_complex_types(),
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_linalg_lu_factor,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack]),
     OpInfo('lu',
            op=torch.lu,
            dtypes=floating_and_complex_types(),
-           supports_inplace_autograd=False,
-           # we use in-place operations which cannot be avoided.
-           # This causes vmap failures, hence we skip batched gradient checks
-           check_batched_grad=False,
           check_batched_gradgrad=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=False,  # need: lu_unpack
@@ -10141,7 +10189,7 @@ def ref_pairwise_distance(input1, input2):
           check_batched_forward_grad=False,
           supports_out=False,
           sample_inputs_func=sample_inputs_lu,
-          decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+          decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
           skips=(
               # we skip jit tests because `lu` is a torch function
               # RuntimeError:
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 0abf713fd97e2..a4643f3ac8b98 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2660,16 +2660,17 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims,

 # Creates a full rank matrix with distinct singular values or
 # a batch of such matrices
-# Shape must be a square matrix or batch of square matrices
-def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype):
-    assert shape[-1] == shape[-2]
-    t = make_tensor(shape, device=device, dtype=dtype)
-    u, _, vh = torch.linalg.svd(t, full_matrices=False)
-    # TODO: improve the handling of complex tensors here
-    real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
-    s = torch.arange(1., shape[-1] + 1, dtype=real_dtype, device=device).mul_(1.0 / (shape[-1] + 1))
-    return (u * s.to(dtype)) @ vh
-
+def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
+    with torch.no_grad():
+        t = make_tensor(shape, device=device, dtype=dtype)
+        u, _, vh = torch.linalg.svd(t, full_matrices=False)
+        # TODO: improve the handling of complex tensors here
+        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
+        k = min(shape[-1], shape[-2])
+        s = torch.arange(1., k + 1, dtype=real_dtype, device=device).mul_(1.0 / (k + 1))
+        x = (u * s.to(dtype)) @ vh
+    x.requires_grad_(requires_grad)
+    return x

 def random_matrix(rows, columns, *batch_dims, **kwargs):
     """Return rectangular matrix or batches of rectangular matrices.
@@ -2687,6 +2688,8 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
         return torch.ones(rows, columns, dtype=dtype, device=device)

     A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
+    if A.numel() == 0:
+        return A
     u, _, vh = torch.linalg.svd(A, full_matrices=False)
     k = min(rows, columns)
     s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
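The rewritten `make_fullrank_matrices_with_distinct_singular_values` above relies on a standard construction: draw a random matrix, take its thin SVD, and substitute the equispaced singular values `1/(k+1), ..., k/(k+1)`, which guarantees full rank and a distinct, well-separated spectrum for any rectangular shape. A self-contained sketch of the same idea for real dtypes (the function name here is illustrative, not the testing helper itself):

    import torch

    def fullrank_distinct_singular_values(*shape, dtype=torch.float64, device="cpu"):
        # random matrix -> thin SVD -> substitute equispaced singular values in (0, 1)
        t = torch.randn(shape, dtype=dtype, device=device)
        u, _, vh = torch.linalg.svd(t, full_matrices=False)
        k = min(shape[-1], shape[-2])
        s = torch.arange(1., k + 1, dtype=dtype, device=device) / (k + 1)
        return (u * s) @ vh  # singular values of the result are exactly the entries of s

    A = fullrank_distinct_singular_values(5, 3)
    print(torch.linalg.matrix_rank(A))  # tensor(3)
    print(torch.linalg.svdvals(A))      # tensor([0.7500, 0.5000, 0.2500], dtype=torch.float64)

The patched helper additionally routes complex dtypes through a real-valued `s` and builds the matrix under `torch.no_grad()`, only re-enabling `requires_grad` on the final product so the SVD itself never enters the autograd graph.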