Skip to content

Commit be17c04

Browse files
djelovina, Denis Jelovina, Vladimir Dimic
authored
532 integrate householder with dense alp containers (#93)
* compile and test with: -make test_alp_zhetrd_alp_reference && tests/smoke/alp_zhetrd_alp_reference -make test_alp_zhetrd_complex_alp_reference && tests/smoke/alp_zhetrd_complex_alp_reference * numerical correctness checked by the Frobenius norm of H - QTQt * additional checks: -matrix Q check, whether it is orthogonal -matrix QTQt check, whether it is a correct solution Non-Householder-related fixes: * print matrices in Python-like style * print matrices with max precision 1.e-10 * initialize identity using view::diagonal * fix missing dot() and norm2() scalar conversion * outer: enable complex conjugate when complex numbers are used * conjugate function on matrices and vectors added * fix the view-related type traits for Hermitian structure * square view fix * move alp_zhetrd to smoke tests * add checks of alp_zhetrd in smoketests.sh (complex and real version) * view() is General by default Known issues: 1) Hermitian currently using Square storage 2) TridiagonalHermitian using Square storage 3) symmetric view of general, done using eWiseLambda 4) fold on non-matching-structure matrices does not work 5) transpose on Symmetric does not work when reused on Square Co-authored-by: Denis Jelovina <denis.jelovina@huawei.com> Co-authored-by: Vladimir Dimic <vladimir.dimic@huawei.com>
1 parent dab75eb commit be17c04

File tree

10 files changed

+590
-180
lines changed

10 files changed

+590
-180
lines changed

include/alp/algorithms/householder_tridiag.hpp

Lines changed: 130 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <sstream>
1919

2020
#include <alp.hpp>
21+
#include <graphblas/utils/iscomplex.hpp> // use from grb
22+
#include "../tests/utils/print_alp_containers.hpp"
2123

2224
namespace alp {
2325

@@ -40,31 +42,59 @@ namespace alp {
4042
*
4143
*/
4244
template<
43-
typename D = double,
45+
typename D,
46+
typename SymmOrHermType,
47+
typename SymmOrHermTridiagonalType,
48+
typename OrthogonalType,
4449
class Ring = Semiring< operators::add< D >, operators::mul< D >, identities::zero, identities::one >,
4550
class Minus = operators::subtract< D >,
46-
class Divide = operators::divide< D > >
51+
class Divide = operators::divide< D >
52+
>
4753
RC householder_tridiag(
48-
Matrix< D, structures::Orthogonal, Dense > &Q,
49-
Matrix< D, structures::SymmetricTridiagonal, Dense > &T, // Need to be add this once alp -> alp is done
50-
const Matrix< D, structures::Symmetric, Dense > &H,
54+
Matrix< D, OrthogonalType, Dense > &Q,
55+
Matrix< D, SymmOrHermTridiagonalType, Dense > &T,
56+
Matrix< D, SymmOrHermType, Dense > &H,
5157
const Ring & ring = Ring(),
5258
const Minus & minus = Minus(),
5359
const Divide & divide = Divide() ) {
5460

5561
RC rc = SUCCESS;
5662

5763
const Scalar< D > zero( ring.template getZero< D >() );
64+
const Scalar< D > one( ring.template getOne< D >() );
5865
const size_t n = nrows( H );
5966

6067
// Q = identity( n )
61-
rc = set( Q , structures::constant::I( n ) );
68+
rc = alp::set( Q, zero );
69+
auto Qdiag = alp::get_view< alp::view::diagonal >( Q );
70+
rc = rc ? rc : alp::set( Qdiag, one );
71+
if( rc != SUCCESS ) {
72+
std::cerr << " set( Q, I ) failed\n";
73+
return rc;
74+
}
6275

6376
// Out of place specification of the computation
64-
Matrix< D, structures::Symmetric, Dense > RR( n );
77+
Matrix< D, SymmOrHermType, Dense > RR( n );
78+
6579
rc = set( RR, H );
80+
if( rc != SUCCESS ) {
81+
std::cerr << " set( RR, H ) failed\n";
82+
return rc;
83+
}
84+
#ifdef DEBUG
85+
print_matrix( " << RR >> ", RR );
86+
#endif
87+
88+
// a temporary for storing the mxm result
89+
Matrix< D, OrthogonalType, Dense > Qtmp( n, n );
6690

6791
for( size_t k = 0; k < n - 2; ++k ) {
92+
#ifdef DEBUG
93+
std::string matname(" << RR(");
94+
matname = matname + std::to_string(k);
95+
matname = matname + std::string( ") >> ");
96+
print_matrix( matname , RR );
97+
#endif
6898

6999
const size_t m = n - k - 1;
70100

@@ -73,82 +103,126 @@ namespace alp {
73103
// alpha = norm( v ) * v[ 0 ] / norm( v[ 0 ] )
74104
// v = v - alpha * e1
75105
// v = v / norm ( v )
76-
Vector< D, structures::General, Dense > v;
77-
rc = set( v, get_view( RR, utils::range( k + 1, n ), k ) );
78106

79-
Scalar< D > alpha;
107+
auto v_view = get_view( RR, k, utils::range( k + 1, n ) );
108+
Vector< D, structures::General, Dense > v( n - ( k + 1 ) );
109+
rc = set( v, v_view );
110+
if( rc != SUCCESS ) {
111+
std::cerr << " set( v, view ) failed\n";
112+
return rc;
113+
}
114+
115+
Scalar< D > alpha( zero );
80116
rc = norm2( alpha, v, ring );
117+
if( rc != SUCCESS ) {
118+
std::cerr << " norm2( alpha, v, ring ) failed\n";
119+
return rc;
120+
}
81121

82122
rc = eWiseLambda(
83-
[ &v, &alpha ]( const size_t i ) {
84-
if ( i == 0 ) {
85-
Scalar< D > norm_v0( std::abs( v[ i ] ) );
86-
foldl(alpha, v [ i ], ring.getMultiplicativeOperator() );
87-
foldl(alpha, norm_v0, divide );
88-
foldl(v [ i ], alpha, minus );
89-
}
90-
},
91-
v );
92-
93-
Scalar< D > norm_v;
123+
[ &alpha, &ring, &divide, &minus ]( const size_t i, D &val ) {
124+
if ( i == 0 ) {
125+
Scalar< D > norm_v0( std::abs( val ) );
126+
Scalar< D > val_scalar( val );
127+
foldl( alpha, val_scalar, ring.getMultiplicativeOperator() );
128+
foldl( alpha, norm_v0, divide );
129+
foldl( val_scalar, alpha, minus );
130+
val = *val_scalar;
131+
}
132+
},
133+
v
134+
);
135+
if( rc != SUCCESS ) {
136+
std::cerr << " eWiseLambda( lambda, v ) failed\n";
137+
return rc;
138+
}
139+
140+
Scalar< D > norm_v( zero );
94141
rc = norm2( norm_v, v, ring );
142+
if( rc != SUCCESS ) {
143+
std::cerr << " norm2( norm_v, v, ring ) failed\n";
144+
return rc;
145+
}
95146

96147
rc = foldl(v, norm_v, divide );
148+
#ifdef DEBUG
149+
print_vector( " v = ", v );
150+
#endif
97151
// ===== End Computing v =====
98152

99-
100153
// ===== Calculate reflector Qk =====
101154
// Q_k = identity( n )
102-
Matrix< D, structures::Symmetric, Dense > Qk( m );
103-
rc = set( Qk, structures::constant::I( m ) );
155+
Matrix< D, SymmOrHermType, Dense > Qk( n );
156+
rc = alp::set( Qk, zero );
157+
auto Qk_diag = alp::get_view< alp::view::diagonal >( Qk );
158+
rc = rc ? rc : alp::set( Qk_diag, one );
104159

105-
Matrix< D, structures::Symmetric, Dense > vvt( m );
106-
rc = set(vvt, zero );
107-
// vvt = v * v^T
108-
rc = outer( vvt, v, v, ring );
160+
// this part can be rewritten without temp matrix using functors
161+
Matrix< D, SymmOrHermType, Dense > vvt( m );
109162

163+
rc = rc ? rc : set( vvt, outer( v, ring.getMultiplicativeOperator() ) );
110164
// vvt = 2 * vvt
111-
rc = foldr( Scalar< D >( 2 ), vvt, ring.getMultiplicativeOperator() );
165+
rc = rc ? rc : foldr( Scalar< D >( 2 ), vvt, ring.getMultiplicativeOperator() );
166+
167+
168+
#ifdef DEBUG
169+
print_matrix( " vvt ", vvt );
170+
#endif
112171

113172
// Qk = Qk - vvt ( expanded: I - 2 * vvt )
114-
rc = foldl( Qk, vvt, minus );
173+
auto Qk_view = get_view< SymmOrHermType >( Qk, utils::range( k + 1, n ), utils::range( k + 1, n ) );
174+
if ( grb::utils::is_complex< D >::value ) {
175+
rc = rc ? rc : foldl( Qk_view, alp::get_view< alp::view::transpose >( vvt ), minus );
176+
} else {
177+
rc = rc ? rc : foldl( Qk_view, vvt, minus );
178+
}
179+
180+
#ifdef DEBUG
181+
print_matrix( " << Qk >> ", Qk );
182+
#endif
115183
// ===== End of Calculate reflector Qk ====
116184

117185
// ===== Update R =====
118-
// Rk = Qk * Rk * Qk^T
119-
120-
// get a view over RR (temporary of R)
121-
auto Rk = get_view( RR, range( k + 1, n ), range( k + 1, n ) );
122-
123-
// QkRk = Qk * Rk
124-
Matrix< D, structures::Square, Dense > QkRk( m );
125-
rc = set( QkRk, zero );
126-
rc = mxm( QkRk, Qk, Rk, ring );
127-
128-
// Rk = QkRk * QkT
129-
rc = set( Rk, zero );
130-
rc = mxm( Rk, QkRk, Qk, ring );
186+
// Rk = Qk * Rk * Qk
187+
188+
// RRQk = RR * Qk
189+
Matrix< D, structures::Square, Dense > RRQk( n );
190+
rc = rc ? rc : set( RRQk, zero );
191+
rc = rc ? rc : mxm( RRQk, RR, Qk, ring );
192+
if( rc != SUCCESS ) {
193+
std::cerr << " mxm( RRQk, RR, Qk, ring ); failed\n";
194+
return rc;
195+
}
196+
#ifdef DEBUG
197+
print_matrix( " << RR x Qk = >> ", RRQk );
198+
#endif
199+
// RR = Qk * RRQk
200+
rc = rc ? rc : set( RR, zero );
201+
rc = rc ? rc : mxm( RR, Qk, RRQk, ring );
202+
203+
#ifdef DEBUG
204+
print_matrix( " << RR( updated ) >> ", RR );
205+
#endif
131206
// ===== End of Update R =====
132207

133208
// ===== Update Q =====
134-
// Q = Q * conjugate( QkT )
135-
// a temporary for storing the mxm result
136-
Matrix< D, structures::Orthogonal, Dense > Qtmp( m, m );
137-
// a view over smaller portion of Q
138-
auto Qprim = get_view( Q, range( k + 1, n ), range( k + 1, n ) );
139-
140-
// Qtmp = Qprim * QkT
141-
rc = set( Qtmp, zero );
142-
rc = mxm( Qtmp, Qprim, Qk, ring );
143-
144-
// Qprim = Qtmp
145-
rc = set( Qprim, Qtmp );
209+
// Q = Q * Qk
210+
211+
// Qtmp = Q * Qk
212+
rc = rc ? rc : set( Qtmp, zero );
213+
rc = rc ? rc : mxm( Qtmp, Q, Qk, ring );
214+
215+
// Q = Qtmp
216+
rc = rc ? rc : set( Q, Qtmp );
217+
#ifdef DEBUG
218+
print_matrix( " << Q updated >> ", Q );
219+
#endif
146220
// ===== End of Update Q =====
147221
}
148222

149223
// T = RR
150-
rc = set( T, get_view< structures::SymmetricTridiagonal > ( RR ) );
151224

225+
rc = rc ? rc : set( T, get_view< SymmOrHermTridiagonalType > ( RR ) );
152226
return rc;
153227
}
154228
} // namespace algorithms

include/alp/reference/blas1.hpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "vector.hpp"
3434
#include "blas0.hpp"
3535
#include "blas2.hpp"
36+
#include <graphblas/utils/iscomplex.hpp> // use from grb
3637

3738
#ifndef NO_CAST_ASSERT
3839
#define NO_CAST_ASSERT( x, y, z ) \
@@ -2508,7 +2509,11 @@ namespace alp {
25082509
std::function< void( typename AddMonoid::D3 &, const size_t, const size_t ) > data_lambda =
25092510
[ &x, &y, &anyOp ]( typename AddMonoid::D3 &result, const size_t i, const size_t j ) {
25102511
(void) j;
2511-
internal::apply( result, x[ i ], y[ i ], anyOp );
2512+
internal::apply(
2513+
result, x[ i ],
2514+
grb::utils::is_complex< InputType2 >::conjugate( y[ i ] ),
2515+
anyOp
2516+
);
25122517
};
25132518

25142519
std::function< bool() > init_lambda =
@@ -2634,15 +2639,15 @@ namespace alp {
26342639
void >::type * const = NULL
26352640
) {
26362641
Scalar< IOType, structures::General, backend > res( x );
2637-
RC rc = alp::dot< descr >( x,
2642+
RC rc = alp::dot< descr >( res,
26382643
left, right,
26392644
ring.getAdditiveMonoid(),
26402645
ring.getMultiplicativeOperator()
26412646
);
26422647
if( rc != SUCCESS ) {
26432648
return rc;
26442649
}
2645-
/** \internal \todo extract res.value into x */
2650+
x = *res;
26462651
return SUCCESS;
26472652
}
26482653

@@ -2896,16 +2901,18 @@ namespace alp {
28962901
class Ring,
28972902
Backend backend
28982903
>
2899-
RC norm2( Scalar< OutputType, OutputStructure, backend > &x,
2904+
RC norm2(
2905+
Scalar< OutputType, OutputStructure, backend > &x,
29002906
const Vector< InputType, InputStructure, Density::Dense, InputView, InputImfR, InputImfC, backend > &y,
29012907
const Ring &ring = Ring(),
29022908
const typename std::enable_if<
2903-
std::is_floating_point< OutputType >::value,
2904-
void >::type * const = NULL
2909+
std::is_floating_point< OutputType >::value || grb::utils::is_complex< OutputType >::value,
2910+
void
2911+
>::type * const = NULL
29052912
) {
29062913
RC ret = alp::dot< descr >( x, y, y, ring );
29072914
if( ret == SUCCESS ) {
2908-
x = sqrt( x );
2915+
*x = sqrt( *x );
29092916
}
29102917
return ret;
29112918
}
@@ -2923,15 +2930,16 @@ namespace alp {
29232930
const Vector< InputType, InputStructure, Density::Dense, InputView, InputImfR, InputImfC, backend > &y,
29242931
const Ring &ring = Ring(),
29252932
const typename std::enable_if<
2926-
std::is_floating_point< OutputType >::value,
2927-
void >::type * const = nullptr
2933+
std::is_floating_point< OutputType >::value || grb::utils::is_complex< OutputType >::value,
2934+
void
2935+
>::type * const = nullptr
29282936
) {
29292937
Scalar< OutputType, structures::General, reference > res( x );
29302938
RC rc = norm2( res, y, ring );
29312939
if( rc != SUCCESS ) {
29322940
return rc;
29332941
}
2934-
/** \internal \todo extract res.value into x */
2942+
x = *res;
29352943
return SUCCESS;
29362944
}
29372945

0 commit comments

Comments
 (0)