Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 202 additions & 1 deletion source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ extern "C"
int *desc,
const int *m, const int *n, const int *mb, const int *nb, const int *irsrc, const int *icsrc,
const int *ictxt, const int *lld, int *info);

// PBLAS dot product of two distributed real vectors (Fortran entry point, so
// every argument is passed by address). The result is returned through `dot`.
// x and y each carry their own start indices (ix, jx / iy, jy), descriptor
// and increment.
void pddot_(int* n, double* dot, double* x, int* ix, int* jx, int* descx, int* incx,
double* y, int* iy, int* jy, int* descy, int* incy);
// PBLAS dot product of two distributed complex vectors (Fortran entry point,
// so every argument is passed by address). The result is returned through
// `dot`. NOTE(review): per the PBLAS naming, the "c" suffix means the first
// vector is conjugated — confirm against the PBLAS reference.
void pzdotc_(int* n, std::complex<double>* dot, std::complex<double>* x, int* ix, int* jx, int* descx, int* incx,
std::complex<double>* y, int* iy, int* jy, int* descy, int* incy);

// ScaLAPACK factorization routine for a distributed real matrix (presumably
// Cholesky, per the LAPACK "potrf" naming). `info` receives the status code.
void pdpotrf_(char *uplo, int *n, double *a, int *ia, int *ja, int *desca, int *info);
// void pzpotrf_(char *uplo, int *n, double _Complex *a, int *ia, int *ja, int *desca, int *info);
Expand Down Expand Up @@ -69,7 +74,10 @@ extern "C"
// PBLAS triangular matrix-matrix product for distributed complex matrices
// (Fortran entry point, so every argument is passed by address).
// a and b each carry their own start indices and descriptor.
void pztrmm_(char *side , char *uplo , char *transa , char *diag , int *m , int *n ,
std::complex<double> *alpha , std::complex<double> *a , int *ia , int *ja , int *desca ,
std::complex<double> *b , int *ib , int *jb , int *descb );

// PBLAS Hermitian matrix-matrix product for distributed complex matrices
// (Fortran entry point, so every argument is passed by address).
// a, b and c each carry their own start indices and descriptor; alpha and
// beta are the usual BLAS scaling factors.
void pzhemm_(char* side , char* uplo , int* m , int* n ,
std::complex<double>* alpha , std::complex<double>* a , int* ia , int* ja , int* desca ,
std::complex<double>* b , int* ib , int* jb , int* descb ,
std::complex<double>* beta , std::complex<double>* c , int* ic , int* jc , int* descc );
void pzgetrf_(
const int *M, const int *N,
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
Expand Down Expand Up @@ -200,6 +208,38 @@ class ScalapackConnector
pzgeadd_(&transa, &m, &n, &alpha, a, &ia, &ja, desca, &beta, c, &ic, &jc, descc);
}

// Dot product of two distributed real vectors; thin wrapper over pddot_.
//   n          number of elements
//   dot        receives the result
//   a, ia, ja, inca   first vector, its start indices and increment
//   b, ib, jb, incb   second vector, its start indices and increment
//   desc       descriptor used for BOTH vectors
static inline
void dot(int n, double& dot,
         double* a, int ia, int ja, int inca,
         double* b, int ib, int jb, int incb,
         int* desc)
{
    // By-value arguments are local copies, so their addresses can be handed
    // straight to the Fortran routine.
    pddot_(&n, &dot,
           a, &ia, &ja, desc, &inca,
           b, &ib, &jb, desc, &incb);
}

// Dot product of two distributed complex vectors; thin wrapper over pzdotc_.
//   n          number of elements
//   dotc       receives the result
//   a, ia, ja, inca   first vector, its start indices and increment
//   b, ib, jb, incb   second vector, its start indices and increment
//   desc       descriptor used for BOTH vectors
// NOTE(review): pzdotc_ is presumably the conjugated variant (conj(a) . b) —
// confirm against the PBLAS reference.
static inline
void dot(int n, std::complex<double>& dotc,
         std::complex<double>* a, int ia, int ja, int inca,
         std::complex<double>* b, int ib, int jb, int incb,
         int* desc)
{
    // By-value arguments are local copies, so their addresses can be handed
    // straight to the Fortran routine.
    pzdotc_(&n, &dotc,
            a, &ia, &ja, desc, &inca,
            b, &ib, &jb, desc, &incb);
}

static inline
void gemm(
const char transa, const char transb,
Expand Down Expand Up @@ -228,6 +268,85 @@ class ScalapackConnector
B, &IB, &JB, DESCB, &beta, C, &IC, &JC, DESCC);
}

// Convenience gemm overload for real matrices that all share one descriptor.
// Forwards to pdgemm_ with DESC used for A, B and C and with every
// start-index argument (IA/JA, IB/JB, IC/JC) fixed to 1.
static inline
void gemm(char transa, char transb, int M, int N, int K,
          double alpha, double* A, double* B,
          double beta, double* C,
          int* DESC)
{
    int one = 1; // shared by all six start-index arguments
    pdgemm_(&transa, &transb, &M, &N, &K,
            &alpha,
            A, &one, &one, DESC,
            B, &one, &one, DESC,
            &beta,
            C, &one, &one, DESC);
}

// Convenience gemm overload for complex matrices that all share one descriptor.
// Forwards to pzgemm_ with DESC used for A, B and C and with every
// start-index argument (IA/JA, IB/JB, IC/JC) fixed to 1.
static inline
void gemm(char transa, char transb, int M, int N, int K,
          std::complex<double> alpha,
          std::complex<double>* A, std::complex<double>* B,
          std::complex<double> beta,
          std::complex<double>* C,
          int* DESC)
{
    int one = 1; // shared by all six start-index arguments
    pzgemm_(&transa, &transb, &M, &N, &K,
            &alpha,
            A, &one, &one, DESC,
            B, &one, &one, DESC,
            &beta,
            C, &one, &one, DESC);
}

// Real symmetric matrix-matrix product; thin wrapper over pdsymm_.
// The single descriptor `desc` is used for a, b and c, and every start-index
// argument is fixed to 1. side/uplo/m/n/alpha/beta are forwarded unchanged.
static inline
void symm(char side, char uplo, int m, int n,
          double alpha, double* a, double* b,
          double beta, double* c,
          int* desc)
{
    int one = 1; // shared by all six start-index arguments
    pdsymm_(&side, &uplo, &m, &n,
            &alpha,
            a, &one, &one, desc,
            b, &one, &one, desc,
            &beta,
            c, &one, &one, desc);
}

static inline
void getrf(
const int M, const int N,
Expand Down Expand Up @@ -263,6 +382,88 @@ class ScalapackConnector
{
pztranu_(&m, &n, &alpha, a, &ia, &ja, desca, &beta, c, &ic, &jc, descc);
}

// In-place factorization of a distributed real matrix via pdpotrf_.
//   uplo   forwarded to pdpotrf_ (selects which triangle is used)
//   na     matrix order
//   U      local matrix data, overwritten by the routine
//   desc   descriptor of U
// Both start-index arguments are fixed to 1.
// Returns the `info` status reported by pdpotrf_.
static inline
int potrf(char uplo, int na, double* U, int* desc)
{
    int one = 1;
    // Initialized so the return value is never indeterminate, even if the
    // routine were to exit without writing it.
    int info = 0;
    pdpotrf_(&uplo, &na, U, &one, &one, desc, &info);
    return info;
}

// In-place factorization of a distributed complex matrix via pzpotrf_.
//   uplo   forwarded to pzpotrf_ (selects which triangle is used)
//   na     matrix order
//   U      local matrix data, overwritten by the routine
//   desc   descriptor of U
// Both start-index arguments are fixed to 1.
// Returns the `info` status reported by pzpotrf_.
static inline
int potrf(char uplo, int na, std::complex<double>* U, int* desc)
{
    int one = 1;
    // Initialized so the return value is never indeterminate, even if the
    // routine were to exit without writing it.
    int info = 0;
    pzpotrf_(&uplo, &na, U, &one, &one, desc, &info);
    return info;
}

// Real triangular matrix-matrix product; thin wrapper over pdtrmm_.
// The single descriptor `desc` is used for both a and b, and every
// start-index argument is fixed to 1. side/uplo/trans/diag/m/n/alpha are
// forwarded unchanged.
static inline
void trmm(char side, char uplo, char trans, char diag,
          int m, int n,
          double alpha, double* a, double* b,
          int* desc)
{
    int one = 1; // shared by all four start-index arguments
    pdtrmm_(&side, &uplo, &trans, &diag, &m, &n,
            &alpha,
            a, &one, &one, desc,
            b, &one, &one, desc);
}

// Complex triangular matrix-matrix product; thin wrapper over pztrmm_.
// The single descriptor `desc` is used for both a and b, and every
// start-index argument is fixed to 1. side/uplo/trans/diag/m/n/alpha are
// forwarded unchanged.
static inline
void trmm(char side, char uplo, char trans, char diag,
          int m, int n,
          std::complex<double> alpha,
          std::complex<double>* a, std::complex<double>* b,
          int* desc)
{
    int one = 1; // shared by all four start-index arguments
    pztrmm_(&side, &uplo, &trans, &diag, &m, &n,
            &alpha,
            a, &one, &one, desc,
            b, &one, &one, desc);
}

// Complex Hermitian matrix-matrix product; thin wrapper over pzhemm_.
// `na` is passed as both the m and n dimension, the single descriptor `desc`
// is used for a, b and c, and every start-index argument is fixed to 1.
static inline
void hemm(char side, char uplo, int na,
          std::complex<double> alpha,
          std::complex<double>* a, std::complex<double>* b,
          std::complex<double> beta,
          std::complex<double>* c,
          int* desc)
{
    int one = 1; // shared by all six start-index arguments
    pzhemm_(&side, &uplo,
            &na, &na, // m and n
            &alpha,
            a, &one, &one, desc,
            b, &one, &one, desc,
            &beta,
            c, &one, &one, desc);
}
};

#endif // __MPI
Expand Down
5 changes: 4 additions & 1 deletion source/module_hsolver/genelpa/elpa_new.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include "elpa_new.h"

#include "elpa_solver.h"
#include "my_math.hpp"
extern "C"
{
#include "Cblacs.h"
}
#include "utils.h"
#include <cfloat>
#include <complex>
Expand Down
38 changes: 19 additions & 19 deletions source/module_hsolver/genelpa/elpa_new_complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "elpa_new.h"
#include "elpa_solver.h"

#include "my_math.hpp"
#include "module_base/scalapack_connector.h"
#include "utils.h"

extern std::map<int, elpa_t> NEW_ELPA_HANDLE_POOL;
Expand Down Expand Up @@ -72,7 +72,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
t=-1;
timer(myid, "A*U^-1", "2.1a", t);
}
Cpzgemm('C', 'N', nFull, 1.0, A, B, 0.0, zwork.data(), desc);
ScalapackConnector::gemm('C', 'N', nFull, nFull, nFull, 1.0, A, B, 0.0, zwork.data(), desc);
if(loglevel>1)
{
timer(myid, "A*U^-1", "2.1a", t);
Expand All @@ -84,7 +84,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
t=-1;
timer(myid, "U^-T*(A*U^-1)", "2.2a", t);
}
Cpzgemm('C', 'N', nFull, 1.0, B, zwork.data(), 0.0, A, desc);
ScalapackConnector::gemm('C', 'N', nFull, nFull, nFull, 1.0, B, zwork.data(), 0.0, A, desc);
if(loglevel>1)
{
timer(myid, "U^-T*(A*U^-1)", "2.2a", t);
Expand All @@ -98,7 +98,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
t=-1;
timer(myid, "B*A^T", "2.1b", t);
}
Cpzgemm('N', 'C', nFull, 1.0, B, A, 0.0, zwork.data(), desc);
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, B, A, 0.0, zwork.data(), desc);
if(loglevel>1)
{
timer(myid, "B*A^T", "2.1b", t);
Expand All @@ -109,7 +109,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
t=-1;
timer(myid, "B*(B*A^T)^T", "2.2b", t);
}
Cpzgemm('N', 'C', nFull, 1.0, B, zwork.data(), 0.0, A, desc);
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, B, zwork.data(), 0.0, A, desc);
if(loglevel>1)
{
timer(myid, "B*(B*A^T)^T", "2.2b", t);
Expand Down Expand Up @@ -168,7 +168,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
t=-1;
timer(myid, "pzpotrf_", "1", t);
}
Cpzpotrf('U', nFull, B, desc);
ScalapackConnector::potrf('U', nFull, B, desc);
if(loglevel>1)
{
timer(myid, "pzpotrf_", "1", t);
Expand Down Expand Up @@ -214,7 +214,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
t=-1;
timer(myid, "pzpotrf_", "2", t);
}
Cpzpotrf('U', nFull, B, desc);
ScalapackConnector::potrf('U', nFull, B, desc);
if(loglevel>1)
{
timer(myid, "pzpotrf_", "2", t);
Expand Down Expand Up @@ -290,7 +290,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
t=-1;
timer(myid, "qevq=qev*q^T", "2", t);
}
Cpzgemm('N', 'C', nFull, 1.0, zwork.data(), EigenVector, 0.0, B, desc);
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, zwork.data(), EigenVector, 0.0, B, desc);
if(loglevel>1)
{
timer(myid, "qevq=qev*q^T", "2", t);
Expand All @@ -310,7 +310,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
t=-1;
timer(myid, "Cpztrmm", "1", t);
}
Cpztrmm('L', 'U', 'N', 'N', nFull, nev, 1.0, B, EigenVector, desc);
ScalapackConnector::trmm('L', 'U', 'N', 'N', nFull, nev, 1.0, B, EigenVector, desc);
if(loglevel>1)
{
timer(myid, "Cpztrmm", "1", t);
Expand All @@ -322,7 +322,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
t=-1;
timer(myid, "Cpzgemm", "1", t);
}
Cpzgemm('C', 'N', nFull, nev, nFull, 1.0, B, zwork.data(), 0.0, EigenVector, desc);
ScalapackConnector::gemm('C', 'N', nFull, nev, nFull, 1.0, B, zwork.data(), 0.0, EigenVector, desc);
if(loglevel>1)
{
timer(myid, "Cpzgemm", "1", t);
Expand Down Expand Up @@ -368,19 +368,19 @@ void ELPA_Solver::verify(std::complex<double>* A, double* EigenValue, std::compl
}

// R=V*D
Cpzhemm('R', 'U', nFull, 1.0, D, V, 0.0, R, desc);
ScalapackConnector::hemm('R', 'U', nFull, 1.0, D, V, 0.0, R, desc);
if(loglevel>2) saveMatrix("VD.dat", nFull, R, desc, cblacs_ctxt);
// R=A*V-V*D=A*V-R
Cpzhemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
ScalapackConnector::hemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
if(loglevel>2) saveMatrix("AV-VD.dat", nFull, R, desc, cblacs_ctxt);
// calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
double sumError=0;
maxError=0;
for(int i=1; i<=nev; ++i)
{
std::complex<double> E;
Cpzdotc(nFull, E, R, 1, i, 1,
R, 1, i, 1, desc);
ScalapackConnector::dot(nFull, E, R, 1, i, 1,
R, 1, i, 1, desc);
double abs_E=std::abs(E);
sumError+=abs_E;
maxError=std::max(maxError, abs_E);
Expand Down Expand Up @@ -427,22 +427,22 @@ void ELPA_Solver::verify(std::complex<double>* A, std::complex<double>* B,
}

// zwork=B*V
Cpzhemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
ScalapackConnector::hemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
if(loglevel>2) saveMatrix("BV.dat", nFull, zwork.data(), desc, cblacs_ctxt);
// R=B*V*D=zwork*D
// (D multiplies from the right, so side='R'; beta=0 overwrites R)
Cpzhemm('R', 'U', nFull, 1.0, D, zwork.data(), 0.0, R, desc);
ScalapackConnector::hemm('R', 'U', nFull, 1.0, D, zwork.data(), 0.0, R, desc);
if(loglevel>2) saveMatrix("BVD.dat", nFull, R, desc, cblacs_ctxt);
// R=A*V-B*V*D=A*V-R
// (beta=-1 subtracts the B*V*D already stored in R)
Cpzhemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
ScalapackConnector::hemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
if(loglevel>2) saveMatrix("AV-BVD.dat", nFull, R, desc, cblacs_ctxt);
// calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
double sumError=0;
maxError=0;
for(int i=1; i<=nev; ++i)
{
std::complex<double> E;
Cpzdotc(nFull, E, R, 1, i, 1,
R, 1, i, 1, desc);
ScalapackConnector::dot(nFull, E, R, 1, i, 1,
R, 1, i, 1, desc);
double abs_E=std::abs(E);
sumError+=abs_E;
maxError=std::max(maxError, abs_E);
Expand Down
Loading
Loading