-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy path low_rank.cpp
50 lines (43 loc) · 1.83 KB
/
low_rank.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <skylark.hpp>
/*******************************************/
namespace skyml = skylark::ml;
/*******************************************/
/// Example driver for skyml::low_rank_t.
///
/// Builds a random square NROW x NROW El::DistMatrix<double> A, constructs a
/// low_rank_t transformer from the command-line parameters, and then either
///   * (-s true)  calls transformer.apply_symmetric(A) and prints the ZU and D
///                members of the returned low_rank_sym_t, or
///   * (-s false) calls transformer.apply_PSD(A) and prints A afterwards.
///
/// Command-line options (parsed by El::Input / El::ProcessInput):
///   -d  number of rows/cols of A                      (default 16)
///   -m  sketching dimension                           (default 8)
///   -l  ml param                                      (default 4)
///   -r  mr param                                      (default 4)
///   -k  k param                                       (default 2)
///   -e  epsilon param                                 (default 0.12)
///   -s  use apply_symmetric instead of apply_PSD      (default false)
///
/// @return EXIT_SUCCESS on success; EXIT_FAILURE if a std::exception escaped
///         the body (it is reported via El::ReportException first).
int main(int argc, char* argv[]) {
    El::Environment env(argc, argv);
    try {
        const El::Unsigned NROW = El::Input("-d","number of rows/cols", 16);
        const El::Unsigned NCOL = NROW;  // A is square
        const El::Unsigned M = El::Input("-m","sketching dimension", 8);
        const El::Unsigned ML = El::Input("-l", "ml param", 4);
        const El::Unsigned MR = El::Input("-r", "mr param", 4);
        const El::Unsigned K = El::Input("-k", "k param", 2);
        const double EPS = El::Input("-e", "epsilon param", 0.12);
        const bool sym = El::Input("-s","Apply Symmetric only", false);
        El::ProcessInput();

        // NOTE(review): A is a general uniform random matrix (it is not
        // symmetrized). If apply_symmetric / apply_PSD expect symmetric or PSD
        // input, build A accordingly; confirm against low_rank_t.
        El::DistMatrix<double> A(NROW, NCOL);
        El::Uniform(A, NROW, NCOL);

        skyml::low_rank_t<double> transformer(K, EPS, M, ML, MR);

        if (sym) {
            skyml::low_rank_sym_t<double> ret = transformer.apply_symmetric(A);
            // The [CIRC,CIRC] copies are made on every rank (outside the
            // rank-0 guard); only rank 0 prints them. Presumably [CIRC,CIRC]
            // gathers the full matrix onto one root process; see Elemental docs.
            El::DistMatrix<double, El::CIRC, El::CIRC> ZU(ret.ZU);
            El::DistMatrix<double, El::CIRC, El::CIRC> D(ret.D);
            El::mpi::Barrier();
            if (El::mpi::Rank() == 0) {
                El::Output("Printing after apply_symmetric:");
                El::Print(ZU, "ZU: ");
                El::Print(D, "D: ");
            }
        } else {
            // Any return value of apply_PSD is discarded and A is printed
            // afterwards. Presumably apply_PSD updates A in place; confirm.
            transformer.apply_PSD(A);
            El::DistMatrix<double, El::CIRC, El::CIRC> a(A);
            if (El::mpi::Rank() == 0) {
                El::Print(a, "After apply_PSD: ");
            }
        }
    } catch (const std::exception& e) {
        El::ReportException(e);
        // Surface the failure to the caller (scripts / CI) instead of
        // reporting success after an error.
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}