sparse_direct: support underdetermined systems in LeastSquares, Ridge, Tikhonov

Drops the height(A) >= width(A) restriction. With m = height(A) and
n = width(A), each routine chooses between two formulations:

  * op(A) has at least as many rows as columns: form op(A)^H op(A), apply
    op(A)^H to Y, then solve (normal equations).
  * otherwise: form op(A) op(A)^H, solve against a copy of Y, then apply
    op(A)^H to the result (minimum-norm solution).

In LeastSquares the ADJOINT orientation has op(A) = A^H, which is n x m, so
that branch tests n >= m rather than m >= n. Testing m >= n there would form
A A^H (m x m, rank <= n) for m > n and A^H A (n x n, rank <= m) for m < n,
i.e. a singular Gram matrix in both branches unless m == n.

NOTE(review): the underdetermined Tikhonov branch accumulates Gamma Gamma^H
into the m x m matrix A A^H, which presumably requires Gamma to have m rows;
confirm against callers.

NOTE(review): indentation in the hunks below is assumed to be 4 spaces per
level; if it does not match the tree, apply with whitespace-insensitive
matching (e.g. `git apply --ignore-whitespace`).

diff --git a/src/sparse_direct/numeric/LeastSquares.cpp b/src/sparse_direct/numeric/LeastSquares.cpp
index 263fa1a00f..521804cff2 100644
--- a/src/sparse_direct/numeric/LeastSquares.cpp
+++ b/src/sparse_direct/numeric/LeastSquares.cpp
@@ -24,32 +24,64 @@ void LeastSquares
         if( orientation != NORMAL && A.Width() != Y.Height() )
             LogicError("Width of A and height of Y must match");
     )
-    if( A.Width() > A.Height() )
-        LogicError("LeastSquares currently assumes height(A) >= width(A)");
+    const Int m = A.Height();
+    const Int n = A.Width();
     DistSparseMatrix<F> C(A.Comm());
+    X.SetComm( Y.Comm() );
     if( orientation == NORMAL )
     {
-        const Int n = A.Width();
-        Herk( LOWER, ADJOINT, Base<F>(1), A, C );
-        MakeHermitian( LOWER, C );
-        X.SetComm( Y.Comm() );
         Zeros( X, n, Y.Width() );
-        Multiply( ADJOINT, F(1), A, Y, F(0), X );
+        if( m >= n )
+        {
+            // Normal equations: (A^H A) X = A^H Y
+            Herk( LOWER, ADJOINT, Base<F>(1), A, C );
+            MakeHermitian( LOWER, C );
+
+            Multiply( ADJOINT, F(1), A, Y, F(0), X );
+            HermitianSolve( C, X, ctrl );
+        }
+        else
+        {
+            // Minimum-norm solution: X = A^H (A A^H)^{-1} Y
+            Herk( LOWER, NORMAL, Base<F>(1), A, C );
+            MakeHermitian( LOWER, C );
+
+            DistMultiVec<F> YCopy(Y.Comm());
+            YCopy = Y;
+            HermitianSolve( C, YCopy, ctrl );
+            Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
+        }
     }
     else if( orientation == ADJOINT || !IsComplex<F>::val )
     {
-        const Int n = A.Height();
-        Herk( LOWER, NORMAL, Base<F>(1), A, C );
-        MakeHermitian( LOWER, C );
-        X.SetComm( Y.Comm() );
-        Zeros( X, n, Y.Width() );
-        Multiply( NORMAL, F(1), A, Y, F(0), X );
+        Zeros( X, m, Y.Width() );
+        // op(A) = A^H is n x m, so the system is overdetermined when n >= m
+        if( n >= m )
+        {
+            // Normal equations: (A A^H) X = A Y
+            Herk( LOWER, NORMAL, Base<F>(1), A, C );
+            MakeHermitian( LOWER, C );
+
+            Multiply( NORMAL, F(1), A, Y, F(0), X );
+            HermitianSolve( C, X, ctrl );
+        }
+        else
+        {
+            // Minimum-norm solution: X = A (A^H A)^{-1} Y
+            Herk( LOWER, ADJOINT, Base<F>(1), A, C );
+            MakeHermitian( LOWER, C );
+
+            DistMultiVec<F> YCopy(Y.Comm());
+            YCopy = Y;
+            HermitianSolve( C, YCopy, ctrl );
+            Multiply( NORMAL, F(1), A, YCopy, F(0), X );
+        }
     }
     else
     {
         LogicError("Complex transposed option not yet supported");
     }
-    HermitianSolve( C, X, ctrl );
+
 }
 
 #define PROTO(F) \
diff --git a/src/sparse_direct/numeric/Ridge.cpp b/src/sparse_direct/numeric/Ridge.cpp
index a94c074de7..6473d7b96c 100644
--- a/src/sparse_direct/numeric/Ridge.cpp
+++ b/src/sparse_direct/numeric/Ridge.cpp
@@ -21,17 +21,32 @@ void Ridge
         if( A.Height() != Y.Height() )
             LogicError("Heights of A and Y must match");
     )
-    if( A.Width() > A.Height() )
-        LogicError("Ridge currently assumes height(A) >= width(A)");
+    const Int m = A.Height();
     const Int n = A.Width();
     DistSparseMatrix<F> C(A.Comm());
-    Herk( LOWER, ADJOINT, Base<F>(1), A, C );
-    UpdateDiagonal( C, F(alpha*alpha) );
-    MakeHermitian( LOWER, C );
+
     X.SetComm( Y.Comm() );
     Zeros( X, n, Y.Width() );
-    Multiply( ADJOINT, F(1), A, Y, F(0), X );
-    HermitianSolve( C, X, ctrl );
+    if( m >= n )
+    {
+        Herk( LOWER, ADJOINT, Base<F>(1), A, C );
+        UpdateDiagonal( C, F(alpha*alpha) );
+        MakeHermitian( LOWER, C );
+
+        Multiply( ADJOINT, F(1), A, Y, F(0), X );
+        HermitianSolve( C, X, ctrl );
+    }
+    else
+    {
+        Herk( LOWER, NORMAL, Base<F>(1), A, C );
+        UpdateDiagonal( C, F(alpha*alpha) );
+        MakeHermitian( LOWER, C );
+
+        DistMultiVec<F> YCopy(Y.Comm());
+        YCopy = Y;
+        HermitianSolve( C, YCopy, ctrl );
+        Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
+    }
 }
 
 #define PROTO(F) \
diff --git a/src/sparse_direct/numeric/Tikhonov.cpp b/src/sparse_direct/numeric/Tikhonov.cpp
index 4d881bab81..f3edbc768b 100644
--- a/src/sparse_direct/numeric/Tikhonov.cpp
+++ b/src/sparse_direct/numeric/Tikhonov.cpp
@@ -21,17 +21,32 @@ void Tikhonov
         if( A.Height() != Y.Height() )
             LogicError("Heights of A and Y must match");
     )
-    if( A.Width() > A.Height() )
-        LogicError("Tikhonov currently assumes height(A) >= width(A)");
+    const Int m = A.Height();
     const Int n = A.Width();
     DistSparseMatrix<F> C(A.Comm());
-    Herk( LOWER, ADJOINT, Base<F>(1), A, C );
-    Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
-    MakeHermitian( LOWER, C );
+
     X.SetComm( Y.Comm() );
     Zeros( X, n, Y.Width() );
-    Multiply( ADJOINT, F(1), A, Y, F(0), X );
-    HermitianSolve( C, X, ctrl );
+    if( m >= n )
+    {
+        Herk( LOWER, ADJOINT, Base<F>(1), A, C );
+        Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
+        MakeHermitian( LOWER, C );
+
+        Multiply( ADJOINT, F(1), A, Y, F(0), X );
+        HermitianSolve( C, X, ctrl );
+    }
+    else
+    {
+        Herk( LOWER, NORMAL, Base<F>(1), A, C );
+        Herk( LOWER, NORMAL, Base<F>(1), Gamma, Base<F>(1), C );
+        MakeHermitian( LOWER, C );
+
+        DistMultiVec<F> YCopy(Y.Comm());
+        YCopy = Y;
+        HermitianSolve( C, YCopy, ctrl );
+        Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
+    }
 }
 
 #define PROTO(F) \