|
| 1 | +/* |
| 2 | + * Copyright 2022 Huawei Technologies Co., Ltd. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include <iostream> |
| 18 | +#include <sstream> |
| 19 | +#include <vector> |
| 20 | +#ifdef _COMPLEX |
| 21 | +#include <complex> |
| 22 | +#include <cmath> |
| 23 | +#include <iomanip> |
| 24 | +#endif |
| 25 | + |
| 26 | +#include <alp.hpp> |
| 27 | +#include <alp/algorithms/householder_tridiag.hpp> |
| 28 | +#include <alp/algorithms/symm_tridiag_eigensolver.hpp> |
| 29 | +#include <graphblas/utils/iscomplex.hpp> // use from grb |
| 30 | +#include "../utils/print_alp_containers.hpp" |
| 31 | + |
| 32 | +using namespace alp; |
| 33 | + |
| 34 | +using BaseScalarType = double; |
| 35 | +using Orthogonal = structures::Orthogonal; |
| 36 | + |
| 37 | +#ifdef _COMPLEX |
| 38 | +using ScalarType = std::complex< BaseScalarType >; |
| 39 | +//not fully implemented structures |
| 40 | +using HermitianOrSymmetricTridiagonal = structures::HermitianTridiagonal; |
| 41 | +using HermitianOrSymmetric = structures::Hermitian; |
| 42 | +#else |
| 43 | +using ScalarType = BaseScalarType; |
| 44 | +using HermitianOrSymmetricTridiagonal = structures::SymmetricTridiagonal; |
| 45 | +//fully implemented structures |
| 46 | +using HermitianOrSymmetric = structures::Symmetric; |
| 47 | +#endif |
| 48 | + |
| 49 | +constexpr BaseScalarType tol = 1.e-10; |
| 50 | +constexpr size_t RNDSEED = 1; |
| 51 | + |
| 52 | +/** Generate symmetric-hermitian matrix in a full storage container */ |
| 53 | +template< typename T > |
| 54 | +std::vector< T > generate_symmherm_matrix_data( |
| 55 | + size_t N, |
| 56 | + const typename std::enable_if< |
| 57 | + grb::utils::is_complex< T >::value, |
| 58 | + void |
| 59 | + >::type * const = nullptr |
| 60 | +) { |
| 61 | + std::vector< T > data( N * N ); |
| 62 | + std::fill( data.begin(), data.end(), static_cast< T >( 0 ) ); |
| 63 | + for( size_t i = 0; i < N; ++i ) { |
| 64 | + for( size_t j = i; j < N; ++j ) { |
| 65 | + T val( std::rand(), std::rand() ); |
| 66 | + data[ i * N + j ] = val / std::abs( val ); |
| 67 | + data[ j * N + i ] += grb::utils::is_complex< T >::conjugate( data[ i * N + j ] ); |
| 68 | + } |
| 69 | + } |
| 70 | + return data; |
| 71 | +} |
| 72 | + |
| 73 | +/** Generate upper/lower triangular part of a Symmetric matrix */ |
| 74 | +template< typename T > |
| 75 | +std::vector< T > generate_symmherm_matrix_data( |
| 76 | + size_t N, |
| 77 | + const typename std::enable_if< |
| 78 | + !grb::utils::is_complex< T >::value, |
| 79 | + void |
| 80 | + >::type * const = nullptr |
| 81 | +) { |
| 82 | + std::vector< T > data( ( N * ( N + 1 ) ) / 2 ); |
| 83 | + std::srand( RNDSEED ); |
| 84 | + size_t k = 0; |
| 85 | + for( size_t i = 0; i < N; ++i ) { |
| 86 | + for( size_t j = i; j < N; ++j ) { |
| 87 | + //data[ k ] = static_cast< T >( i + j*j ); // easily reproducible |
| 88 | + data[ k ] = static_cast< T >( std::rand() ) / RAND_MAX; |
| 89 | + ++k; |
| 90 | + } |
| 91 | + } |
| 92 | + return data; |
| 93 | +} |
| 94 | + |
| 95 | +/** Check if rows/columns or matrix Q are orthogonal */ |
| 96 | +template< |
| 97 | + typename T, |
| 98 | + typename Structure, |
| 99 | + typename ViewType, |
| 100 | + std::enable_if_t< |
| 101 | + structures::is_a< Structure, structures::Orthogonal >::value |
| 102 | + > * = nullptr, |
| 103 | + class Ring = Semiring< operators::add< T >, operators::mul< T >, identities::zero, identities::one >, |
| 104 | + class Minus = operators::subtract< T > |
| 105 | +> |
| 106 | +RC check_overlap( |
| 107 | + alp::Matrix< T, Structure, alp::Density::Dense, ViewType > &Q, |
| 108 | + const Ring &ring = Ring(), |
| 109 | + const Minus &minus = Minus() |
| 110 | +) { |
| 111 | + const Scalar< T > zero( ring.template getZero< T >() ); |
| 112 | + const Scalar< T > one( ring.template getOne< T >() ); |
| 113 | + |
| 114 | + RC rc = SUCCESS; |
| 115 | + const size_t n = nrows( Q ); |
| 116 | + |
| 117 | + // check if QxQt == I |
| 118 | + alp::Matrix< T, Structure, alp::Density::Dense, ViewType > Qtmp( n ); |
| 119 | + rc = rc ? rc : set( Qtmp, zero ); |
| 120 | + rc = rc ? rc : mxm( |
| 121 | + Qtmp, |
| 122 | + Q, |
| 123 | + conjugate( alp::get_view< alp::view::transpose >( Q ) ), |
| 124 | + ring |
| 125 | + ); |
| 126 | + // For Identity we use Structure (Orthogonal structure), |
| 127 | + // as later we use fold with Qtmp (Orthogonal matrix) |
| 128 | + Matrix< T, Structure, Dense > Identity( n ); |
| 129 | + rc = rc ? rc : alp::set( Identity, zero ); |
| 130 | + auto id_diag = alp::get_view< alp::view::diagonal >( Identity ); |
| 131 | + rc = rc ? rc : alp::set( id_diag, one ); |
| 132 | + rc = rc ? rc : foldl( Qtmp, Identity, minus ); |
| 133 | + |
| 134 | + //Frobenius norm |
| 135 | + T fnorm = ring.template getZero< T >(); |
| 136 | + rc = rc ? rc : alp::eWiseLambda( |
| 137 | + [ &fnorm, &ring ]( const size_t i, const size_t j, T &val ) { |
| 138 | + (void) i; |
| 139 | + (void) j; |
| 140 | + internal::foldl( fnorm, val * val, ring.getAdditiveOperator() ); |
| 141 | + }, |
| 142 | + Qtmp |
| 143 | + ); |
| 144 | + fnorm = std::sqrt( fnorm ); |
| 145 | + |
| 146 | +#ifdef DEBUG |
| 147 | + std::cout << " FrobeniusNorm(QQt - I) = " << std::abs( fnorm ) << "\n"; |
| 148 | +#endif |
| 149 | + if( tol < std::abs( fnorm ) ) { |
| 150 | + std::cout << "The Frobenius norm is too large: " << std::abs( fnorm ) << ".\n"; |
| 151 | + return FAILED; |
| 152 | + } |
| 153 | + |
| 154 | + return rc; |
| 155 | +} |
| 156 | + |
| 157 | + |
| 158 | +/** Check the solution by calculating A x Q - Q x diag(d) */ |
| 159 | +template< |
| 160 | + typename D, |
| 161 | + typename SymmOrHermTridiagonalType, |
| 162 | + typename OrthogonalType, |
| 163 | + typename SymmHermTrdiViewType, |
| 164 | + typename OrthViewType, |
| 165 | + typename SymmHermTrdiImfR, |
| 166 | + typename SymmHermTrdiImfC, |
| 167 | + typename OrthViewImfR, |
| 168 | + typename OrthViewImfC, |
| 169 | + typename VecViewType, |
| 170 | + typename VecImfR, |
| 171 | + typename VecImfC, |
| 172 | + class Ring = Semiring< operators::add< D >, operators::mul< D >, identities::zero, identities::one >, |
| 173 | + class Minus = operators::subtract< D >, |
| 174 | + class Divide = operators::divide< D > |
| 175 | +> |
| 176 | +RC check_solution( |
| 177 | + Matrix< D, SymmOrHermTridiagonalType, Dense, SymmHermTrdiViewType, SymmHermTrdiImfR, SymmHermTrdiImfC > &T, |
| 178 | + Matrix< D, OrthogonalType, Dense, OrthViewType, OrthViewImfR, OrthViewImfC > &Q, |
| 179 | + Vector< D, structures::General, Dense, VecViewType, VecImfR, VecImfC > &d, |
| 180 | + const Ring &ring = Ring(), |
| 181 | + const Minus &minus = Minus(), |
| 182 | + const Divide ÷ = Divide() |
| 183 | +) { |
| 184 | + (void) ring; |
| 185 | + (void) minus; |
| 186 | + (void) divide; |
| 187 | + RC rc = SUCCESS; |
| 188 | + |
| 189 | + const size_t n = nrows( Q ); |
| 190 | + |
| 191 | +#ifdef DEBUG |
| 192 | + print_matrix( " T ", T ); |
| 193 | + print_matrix( " Q ", Q ); |
| 194 | + print_vector( " d ", d ); |
| 195 | +#endif |
| 196 | + |
| 197 | + alp::Matrix< D, alp::structures::Square, alp::Density::Dense > Left( n ); |
| 198 | + alp::Matrix< D, alp::structures::Square, alp::Density::Dense > Right( n ); |
| 199 | + alp::Matrix< D, alp::structures::Square, alp::Density::Dense > Dmat( n ); |
| 200 | + const Scalar< D > zero( ring.template getZero< D >() ); |
| 201 | + const Scalar< D > one( ring.template getOne< D >() ); |
| 202 | + |
| 203 | + rc = rc ? rc : set( Left, zero ); |
| 204 | + rc = rc ? rc : mxm( Left, T, Q, ring ); |
| 205 | + |
| 206 | + rc = rc ? rc : set( Dmat, zero ); |
| 207 | + auto D_diag = alp::get_view< alp::view::diagonal >( Dmat ); |
| 208 | + rc = rc ? rc : set( D_diag, d ); |
| 209 | + rc = rc ? rc : set( Right, zero ); |
| 210 | + rc = rc ? rc : mxm( Right, Q, Dmat, ring ); |
| 211 | +#ifdef DEBUG |
| 212 | + print_matrix( " TxQ ", Left ); |
| 213 | + print_matrix( " QxD ", Right ), |
| 214 | +#endif |
| 215 | + rc = rc ? rc : foldl( Left, Right, minus ); |
| 216 | + |
| 217 | + //Frobenius norm |
| 218 | + D fnorm = ring.template getZero< D >(); |
| 219 | + rc = rc ? rc : alp::eWiseLambda( |
| 220 | + [ &fnorm, &ring ]( const size_t i, const size_t j, D &val ) { |
| 221 | + (void) i; |
| 222 | + (void) j; |
| 223 | + internal::foldl( fnorm, val * val, ring.getAdditiveOperator() ); |
| 224 | + }, |
| 225 | + Left |
| 226 | + ); |
| 227 | + fnorm = std::sqrt( fnorm ); |
| 228 | + |
| 229 | +#ifdef DEBUG |
| 230 | + std::cout << " FrobeniusNorm(AQ-QD) = " << std::abs( fnorm ) << "\n"; |
| 231 | +#endif |
| 232 | + if( tol < std::abs( fnorm ) ) { |
| 233 | + std::cout << "The Frobenius norm is too large: " << std::abs( fnorm ) << ".\n"; |
| 234 | + return FAILED; |
| 235 | + } |
| 236 | + |
| 237 | + return rc; |
| 238 | +} |
| 239 | + |
| 240 | +void alp_program( const size_t &unit, alp::RC &rc ) { |
| 241 | + rc = SUCCESS; |
| 242 | + |
| 243 | + alp::Semiring< |
| 244 | + alp::operators::add< ScalarType >, |
| 245 | + alp::operators::mul< ScalarType >, |
| 246 | + alp::identities::zero, |
| 247 | + alp::identities::one |
| 248 | + > ring; |
| 249 | + const Scalar< ScalarType > zero( ring.template getZero< ScalarType >() ); |
| 250 | + |
| 251 | + // dimensions of sqare matrices H, Q and R |
| 252 | + size_t N = unit; |
| 253 | + |
| 254 | + alp::Matrix< ScalarType, Orthogonal > Q( N ); //output eigenvectors |
| 255 | + alp::Matrix< ScalarType, Orthogonal > Q1( N ); //temp orthogonal matrix |
| 256 | + alp::Matrix< ScalarType, Orthogonal > Q2( N ); //temp orthogonal matrix |
| 257 | + alp::Matrix< ScalarType, HermitianOrSymmetricTridiagonal > T( N ); //temptridiagonal matrix |
| 258 | + alp::Matrix< ScalarType, HermitianOrSymmetric > H( N ); //input matrix |
| 259 | + Vector< ScalarType, structures::General, Dense > d( N ); //output eigenvalues |
| 260 | + { |
| 261 | + std::srand( RNDSEED ); |
| 262 | + auto matrix_data = generate_symmherm_matrix_data< ScalarType >( N ); |
| 263 | + rc = rc ? rc : alp::buildMatrix( H, matrix_data.begin(), matrix_data.end() ); |
| 264 | + } |
| 265 | +#ifdef DEBUG |
| 266 | + print_matrix( " input matrix H ", H ); |
| 267 | +#endif |
| 268 | + |
| 269 | + rc = rc ? rc : set( Q1, zero ); |
| 270 | + rc = rc ? rc : set( Q2, zero ); |
| 271 | + rc = rc ? rc : set( Q, zero ); |
| 272 | + |
| 273 | + rc = rc ? rc : algorithms::householder_tridiag( Q1, T, H, ring ); |
| 274 | + rc = rc ? rc : algorithms::symm_tridiag_dac_eigensolver( T, Q2, d, ring ); |
| 275 | + rc = rc ? rc : alp::mxm( Q, Q1, Q2, ring ); |
| 276 | + |
| 277 | +#ifdef DEBUG |
| 278 | + print_matrix( " Q1 ", Q1 ); |
| 279 | + print_matrix( " Q2 ", Q2 ); |
| 280 | + print_matrix( " Q ", Q ); |
| 281 | + print_matrix( " T ", T ); |
| 282 | +#endif |
| 283 | + |
| 284 | + // the algorithm should return correct eigenvalues |
| 285 | + // but for larger matrices (n>20) a more stable calculations |
| 286 | + // of eigenvectors is needed |
| 287 | + // therefore we disable numerical correctness check in this version |
| 288 | + |
| 289 | + // rc = check_overlap( Q ); |
| 290 | + // if( rc != SUCCESS ) { |
| 291 | + // std::cout << "Error: mratrix Q is not orthogonal\n"; |
| 292 | + // } |
| 293 | + |
| 294 | + // rc = check_solution( H, Q, d ); |
| 295 | + // if( rc != SUCCESS ) { |
| 296 | + // std::cout << "Error: solution numerically wrong\n"; |
| 297 | + // } |
| 298 | +} |
| 299 | + |
| 300 | +int main( int argc, char **argv ) { |
| 301 | + // defaults |
| 302 | + bool printUsage = false; |
| 303 | + size_t in = 5; |
| 304 | + |
| 305 | + // error checking |
| 306 | + if( argc > 2 ) { |
| 307 | + printUsage = true; |
| 308 | + } |
| 309 | + if( argc == 2 ) { |
| 310 | + size_t read; |
| 311 | + std::istringstream ss( argv[ 1 ] ); |
| 312 | + if( ! ( ss >> read ) ) { |
| 313 | + std::cerr << "Error parsing first argument\n"; |
| 314 | + printUsage = true; |
| 315 | + } else if( ! ss.eof() ) { |
| 316 | + std::cerr << "Error parsing first argument\n"; |
| 317 | + printUsage = true; |
| 318 | + } else if( read % 2 != 0 ) { |
| 319 | + std::cerr << "Given value for n is odd\n"; |
| 320 | + printUsage = true; |
| 321 | + } else { |
| 322 | + // all OK |
| 323 | + in = read; |
| 324 | + } |
| 325 | + } |
| 326 | + if( printUsage ) { |
| 327 | + std::cerr << "Usage: " << argv[ 0 ] << " [n]\n"; |
| 328 | + std::cerr << " -n (optional, default is 100): an even integer, the " |
| 329 | + "test size.\n"; |
| 330 | + return 1; |
| 331 | + } |
| 332 | + |
| 333 | + std::cout << "This is functional test " << argv[ 0 ] << "\n"; |
| 334 | + alp::Launcher< AUTOMATIC > launcher; |
| 335 | + alp::RC out; |
| 336 | + if( launcher.exec( &alp_program, in, out, true ) != SUCCESS ) { |
| 337 | + std::cerr << "Launching test FAILED\n"; |
| 338 | + return 255; |
| 339 | + } |
| 340 | + if( out != SUCCESS ) { |
| 341 | + std::cerr << "Test FAILED (" << alp::toString( out ) << ")" << std::endl; |
| 342 | + } else { |
| 343 | + std::cout << "Test OK" << std::endl; |
| 344 | + } |
| 345 | + return 0; |
| 346 | +} |
0 commit comments