8
8
#include < TiledArray/conversions/btas.h>
9
9
#include < TiledArray/expressions/einsum.h>
10
10
#include < tiledarray.h>
11
- #include < random>
12
11
13
12
namespace TiledArray ::cp {
13
+
14
14
namespace detail {
15
// Accessor for the seed used to generate random CP factor guesses.
// Returns a mutable reference so callers may both query and reset the seed;
// the default seed is 3.
static inline unsigned int& random_seed_accessor() {
  static unsigned int seed = 3;
  return seed;
}
20
-
21
- // given a rank and block size, this computes a
22
- // trange for the rank dimension to be used to make the CP factors.
23
- static inline TiledRange1 compute_trange1 (size_t rank, size_t rank_block_size) {
24
- std::size_t nblocks = (rank + rank_block_size - 1 ) / rank_block_size;
25
- auto dv = std::div ((int )(rank + nblocks - 1 ), (int )nblocks);
26
- auto avg_block_size = dv.quot - 1 , num_avg_plus_one = dv.rem + 1 ;
27
-
28
- TiledArray::TiledRange1 new_trange1;
29
- {
30
- std::vector<std::size_t > new_trange1_v;
31
- new_trange1_v.reserve (nblocks + 1 );
32
- auto block_counter = 0 ;
33
- for (auto i = 0 ; i < num_avg_plus_one;
34
- ++i, block_counter += avg_block_size + 1 ) {
35
- new_trange1_v.emplace_back (block_counter);
36
- }
37
- for (auto i = num_avg_plus_one; i < nblocks;
38
- ++i, block_counter += avg_block_size) {
39
- new_trange1_v.emplace_back (block_counter);
40
- }
41
- new_trange1_v.emplace_back (rank);
42
- new_trange1 =
43
- TiledArray::TiledRange1 (new_trange1_v.begin (), new_trange1_v.end ());
44
- }
45
- return new_trange1;
46
- }
47
15
48
16
// Maps a small non-negative index to a lowercase letter (0 -> 'a', 1 -> 'b',
// ...), used to build annotation/index strings. Only meaningful for
// i in [0, 25]; larger values walk past 'z'.
static inline char intToAlphabet(int i) { return static_cast<char>(i + 'a'); }
49
17
@@ -111,13 +79,13 @@ class CP {
111
79
if (build_rank) {
112
80
size_t cur_rank = 1 ;
113
81
do {
114
- rank_trange = detail ::compute_trange1 (cur_rank, rank_block_size);
82
+ rank_trange = TiledArray ::compute_trange1 (cur_rank, rank_block_size);
115
83
build_guess (cur_rank, rank_trange);
116
84
ALS (cur_rank, 100 , verbose);
117
85
++cur_rank;
118
86
} while (cur_rank < rank);
119
87
} else {
120
- rank_trange = detail ::compute_trange1 (rank, rank_block_size);
88
+ rank_trange = TiledArray ::compute_trange1 (rank, rank_block_size);
121
89
build_guess (rank, rank_trange);
122
90
ALS (rank, 100 , verbose);
123
91
}
@@ -143,7 +111,7 @@ class CP {
143
111
double epsilon = 1.0 ;
144
112
fit_tol = epsilonALS;
145
113
do {
146
- auto rank_trange = detail:: compute_trange1 (cur_rank, rank_block_size);
114
+ auto rank_trange = compute_trange1 (cur_rank, rank_block_size);
147
115
build_guess (cur_rank, rank_trange);
148
116
ALS (cur_rank, 100 , verbose);
149
117
++cur_rank;
@@ -196,9 +164,10 @@ class CP {
196
164
final_fit, // The final fit of the ALS
197
165
// optimization at fixed rank.
198
166
fit_tol, // Tolerance for the ALS solver
199
- converged_num, // How many times the ALS solver
200
- // has changed less than the tolerance
201
167
norm_reference; // used in determining the CP fit.
168
+ std::size_t converged_num =
169
+ 0 ; // How many times the ALS solver
170
+ // has changed less than the tolerance in a row
202
171
203
172
// / This function is determined by the specific CP solver.
204
173
// / builds the rank @c rank CP approximation and stores
@@ -227,14 +196,12 @@ class CP {
227
196
auto lambda = std::vector<typename Tile::value_type>(
228
197
rank, (typename Tile::value_type)0 );
229
198
if (world.rank () == 0 ) {
230
- std::mt19937 generator (detail::random_seed_accessor ());
231
- std::uniform_real_distribution<> distribution (-1.0 , 1.0 );
232
199
auto factor_ptr = factor.data ();
233
200
size_t offset = 0 ;
234
201
for (auto r = 0 ; r < rank; ++r, offset += mode_size) {
235
202
auto lam_ptr = lambda.data () + r;
236
203
for (auto m = offset; m < offset + mode_size; ++m) {
237
- auto val = distribution (generator);
204
+ auto val = TiledArray::drand () * 2 - 1 ; // random number in [-1,1]
238
205
*(factor_ptr + m) = val;
239
206
*lam_ptr += val * val;
240
207
}
@@ -364,7 +331,7 @@ class CP {
364
331
// / \returns bool : is the change in fit less than the ALS tolerance?
365
332
virtual bool check_fit (bool verbose = false ) {
366
333
// Compute the inner product T * T_CP
367
- double inner_prod = MTtKRP (" r,n" ).dot (unNormalized_Factor (" r,n" ));
334
+ const auto ref_dot_cp = MTtKRP (" r,n" ).dot (unNormalized_Factor (" r,n" ));
368
335
// compute the square of the CP tensor (can use the grammian)
369
336
auto factor_norm = [&]() {
370
337
auto gram_ptr = partial_grammian.begin ();
@@ -380,27 +347,39 @@ class CP {
380
347
return result;
381
348
};
382
349
// compute the error in the loss function and find the fit
383
- double normFactors = factor_norm (),
384
- normResidual =
385
- sqrt (abs (norm_reference * norm_reference +
386
- normFactors * normFactors - 2.0 * inner_prod)),
387
- fit = 1.0 - (normResidual / norm_reference),
388
- fit_change = abs (prev_fit - fit);
350
+ const auto norm_cp = factor_norm (); // ||T_CP||_2
351
+ const auto squared_norm_error = norm_reference * norm_reference +
352
+ norm_cp * norm_cp -
353
+ 2.0 * ref_dot_cp; // ||T - T_CP||_2^2
354
+ // N.B. squared_norm_error is very noisy
355
+ // TA_ASSERT(squared_norm_error >= - 1e-8);
356
+ const auto norm_error = sqrt (abs (squared_norm_error));
357
+ const auto fit = 1.0 - (norm_error / norm_reference);
358
+ const auto fit_change = fit - prev_fit;
389
359
prev_fit = fit;
390
360
// print fit data if required
391
361
if (verbose) {
392
- std::cout << fit << " \t " << fit_change << std::endl;
362
+ std::cout << MTtKRP.world ().rank () << " : fit=" << fit
363
+ << " fit_change=" << fit_change << std::endl;
393
364
}
394
365
395
366
// if the change in fit is less than the tolerance try to return true.
396
- if (fit_change < fit_tol) {
367
+ if (abs ( fit_change) < fit_tol) {
397
368
converged_num++;
398
369
if (converged_num == 2 ) {
399
370
converged_num = 0 ;
400
371
final_fit = prev_fit;
401
372
prev_fit = 1.0 ;
373
+ if (verbose)
374
+ std::cout << MTtKRP.world ().rank () << " : converged" << std::endl;
402
375
return true ;
376
+ } else {
377
+ TA_ASSERT (converged_num == 1 );
378
+ if (verbose)
379
+ std::cout << MTtKRP.world ().rank () << " : pre-converged" << std::endl;
403
380
}
381
+ } else {
382
+ converged_num = 0 ;
404
383
}
405
384
return false ;
406
385
}
0 commit comments