Skip to content

Commit 9c2a18d

Browse files
authored
Merge pull request #394 from ValeevGroup/evaleev/feature/reentrant-rand
reentrant `TA::rand()`
2 parents c5a7a45 + 75a97df commit 9c2a18d

20 files changed

+501
-275
lines changed

INSTALL.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Both methods are supported. However, for most users we _strongly_ recommend to b
4040
- Boost.Container: header-only
4141
- Boost.Test: header-only or (optionally) as a compiled library, *only used for unit testing*
4242
- Boost.Range: header-only, *only used for unit testing*
43-
- [BTAS](http://github.com/ValeevGroup/BTAS), tag 474ddc095cbea12a1d28aca5435703dd9f69b166 . If usable BTAS installation is not found, TiledArray will download and compile
43+
- [BTAS](http://github.com/ValeevGroup/BTAS), tag 6fcb6451bc7ca46a00534a30c51dc5c230c39ac3. If a usable BTAS installation is not found, TiledArray will download and compile
4444
BTAS from source. *This is the recommended way to compile BTAS for all users*.
4545
- [MADNESS](https://github.com/m-a-d-n-e-s-s/madness), tag 0b44ef319643cb9721fbe17d294987c146e6460e .
4646
Only the MADworld runtime and the BLAS/LAPACK C API components of MADNESS are used by TiledArray.

external/versions.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 29a2bf3d3c2670c608b7bfdf2299d76fbc20e041)
2424
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
2525
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)
2626

27-
set(TA_TRACKED_BTAS_TAG 474ddc095cbea12a1d28aca5435703dd9f69b166)
28-
set(TA_TRACKED_BTAS_PREVIOUS_TAG 2917aa21465a93ae6f399874f247b5fe31d6b693)
27+
set(TA_TRACKED_BTAS_TAG 6fcb6451bc7ca46a00534a30c51dc5c230c39ac3)
28+
set(TA_TRACKED_BTAS_PREVIOUS_TAG 474ddc095cbea12a1d28aca5435703dd9f69b166)
2929

3030
set(TA_TRACKED_LIBRETT_TAG 68abe31a9ec6fd2fd9ffbcd874daa80457f947da)
3131
set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 7e27ac766a9038df6aa05613784a54a036c4b796)

src/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,12 @@ TiledArray/util/initializer_list.h
195195
TiledArray/util/logger.h
196196
TiledArray/util/ptr_registry.cpp
197197
TiledArray/util/ptr_registry.h
198+
TiledArray/util/random.cpp
198199
TiledArray/util/random.h
199200
TiledArray/util/singleton.h
200201
TiledArray/util/threads.h
201202
TiledArray/util/threads.cpp
203+
TiledArray/util/thread_specific.h
202204
TiledArray/util/time.h
203205
TiledArray/util/vector.h
204206
)

src/TiledArray/conversions/eigen.h

+3-5
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,10 @@ DistArray_ eigen_tensor_to_array(
963963
/// replicated TiledArray::DistArray. Usage:
964964
/// \code
965965
/// TiledArray::TArrayD
966-
/// array(world, trange);
966+
/// array(world, trange_3d);
967967
/// // Set tiles of array ...
968968
///
969-
/// auto t = array_to_eigen_tensor(array);
969+
/// auto t = array_to_eigen_tensor<Eigen::Tensor<double, 3>>(array);
970970
/// \endcode
971971
/// \tparam Tile the tile type of \c src
972972
/// \tparam Policy the policy type of \c src
@@ -980,13 +980,11 @@ DistArray_ eigen_tensor_to_array(
980980
/// create the Eigen::Tensor on every rank (this requires
981981
/// that \c src.is_replicated()==true )
982982
/// \return Eigen::Tensor object containing the data of \c src , if my rank
983-
/// equals
984-
/// \c target_rank or \c target_rank==-1 ,
983+
/// equals \c target_rank or \c target_rank==-1 ,
985984
/// default-initialized Eigen::Tensor otherwise.
986985
template <typename Tensor, typename Tile, typename Policy>
987986
Tensor array_to_eigen_tensor(const TiledArray::DistArray<Tile, Policy>& src,
988987
int target_rank = -1) {
989-
990988
TA_ASSERT(src.tiles_range().rank() == Tensor::NumDimensions);
991989

992990
// Test preconditions

src/TiledArray/cp/cp.h

+29-50
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,10 @@
88
#include <TiledArray/conversions/btas.h>
99
#include <TiledArray/expressions/einsum.h>
1010
#include <tiledarray.h>
11-
#include <random>
1211

1312
namespace TiledArray::cp {
13+
1414
namespace detail {
15-
// A seed for the random number generator.
16-
static inline unsigned int& random_seed_accessor() {
17-
static unsigned int value = 3;
18-
return value;
19-
}
20-
21-
// given a rank and block size, this computes a
22-
// trange for the rank dimension to be used to make the CP factors.
23-
static inline TiledRange1 compute_trange1(size_t rank, size_t rank_block_size) {
24-
std::size_t nblocks = (rank + rank_block_size - 1) / rank_block_size;
25-
auto dv = std::div((int)(rank + nblocks - 1), (int)nblocks);
26-
auto avg_block_size = dv.quot - 1, num_avg_plus_one = dv.rem + 1;
27-
28-
TiledArray::TiledRange1 new_trange1;
29-
{
30-
std::vector<std::size_t> new_trange1_v;
31-
new_trange1_v.reserve(nblocks + 1);
32-
auto block_counter = 0;
33-
for (auto i = 0; i < num_avg_plus_one;
34-
++i, block_counter += avg_block_size + 1) {
35-
new_trange1_v.emplace_back(block_counter);
36-
}
37-
for (auto i = num_avg_plus_one; i < nblocks;
38-
++i, block_counter += avg_block_size) {
39-
new_trange1_v.emplace_back(block_counter);
40-
}
41-
new_trange1_v.emplace_back(rank);
42-
new_trange1 =
43-
TiledArray::TiledRange1(new_trange1_v.begin(), new_trange1_v.end());
44-
}
45-
return new_trange1;
46-
}
4715

4816
static inline char intToAlphabet(int i) { return static_cast<char>('a' + i); }
4917

@@ -111,13 +79,13 @@ class CP {
11179
if (build_rank) {
11280
size_t cur_rank = 1;
11381
do {
114-
rank_trange = detail::compute_trange1(cur_rank, rank_block_size);
82+
rank_trange = TiledArray::compute_trange1(cur_rank, rank_block_size);
11583
build_guess(cur_rank, rank_trange);
11684
ALS(cur_rank, 100, verbose);
11785
++cur_rank;
11886
} while (cur_rank < rank);
11987
} else {
120-
rank_trange = detail::compute_trange1(rank, rank_block_size);
88+
rank_trange = TiledArray::compute_trange1(rank, rank_block_size);
12189
build_guess(rank, rank_trange);
12290
ALS(rank, 100, verbose);
12391
}
@@ -143,7 +111,7 @@ class CP {
143111
double epsilon = 1.0;
144112
fit_tol = epsilonALS;
145113
do {
146-
auto rank_trange = detail::compute_trange1(cur_rank, rank_block_size);
114+
auto rank_trange = compute_trange1(cur_rank, rank_block_size);
147115
build_guess(cur_rank, rank_trange);
148116
ALS(cur_rank, 100, verbose);
149117
++cur_rank;
@@ -196,9 +164,10 @@ class CP {
196164
final_fit, // The final fit of the ALS
197165
// optimization at fixed rank.
198166
fit_tol, // Tolerance for the ALS solver
199-
converged_num, // How many times the ALS solver
200-
// has changed less than the tolerance
201167
norm_reference; // used in determining the CP fit.
168+
std::size_t converged_num =
169+
0; // How many times the ALS solver
170+
// has changed less than the tolerance in a row
202171

203172
/// This function is determined by the specific CP solver.
204173
/// builds the rank @c rank CP approximation and stores
@@ -227,14 +196,12 @@ class CP {
227196
auto lambda = std::vector<typename Tile::value_type>(
228197
rank, (typename Tile::value_type)0);
229198
if (world.rank() == 0) {
230-
std::mt19937 generator(detail::random_seed_accessor());
231-
std::uniform_real_distribution<> distribution(-1.0, 1.0);
232199
auto factor_ptr = factor.data();
233200
size_t offset = 0;
234201
for (auto r = 0; r < rank; ++r, offset += mode_size) {
235202
auto lam_ptr = lambda.data() + r;
236203
for (auto m = offset; m < offset + mode_size; ++m) {
237-
auto val = distribution(generator);
204+
auto val = TiledArray::drand() * 2 - 1; // random number in [-1,1]
238205
*(factor_ptr + m) = val;
239206
*lam_ptr += val * val;
240207
}
@@ -364,7 +331,7 @@ class CP {
364331
/// \returns bool : is the change in fit less than the ALS tolerance?
365332
virtual bool check_fit(bool verbose = false) {
366333
// Compute the inner product T * T_CP
367-
double inner_prod = MTtKRP("r,n").dot(unNormalized_Factor("r,n"));
334+
const auto ref_dot_cp = MTtKRP("r,n").dot(unNormalized_Factor("r,n"));
368335
// compute the square of the CP tensor (can use the grammian)
369336
auto factor_norm = [&]() {
370337
auto gram_ptr = partial_grammian.begin();
@@ -380,27 +347,39 @@ class CP {
380347
return result;
381348
};
382349
// compute the error in the loss function and find the fit
383-
double normFactors = factor_norm(),
384-
normResidual =
385-
sqrt(abs(norm_reference * norm_reference +
386-
normFactors * normFactors - 2.0 * inner_prod)),
387-
fit = 1.0 - (normResidual / norm_reference),
388-
fit_change = abs(prev_fit - fit);
350+
const auto norm_cp = factor_norm(); // ||T_CP||_2
351+
const auto squared_norm_error = norm_reference * norm_reference +
352+
norm_cp * norm_cp -
353+
2.0 * ref_dot_cp; // ||T - T_CP||_2^2
354+
// N.B. squared_norm_error is very noisy
355+
// TA_ASSERT(squared_norm_error >= - 1e-8);
356+
const auto norm_error = sqrt(abs(squared_norm_error));
357+
const auto fit = 1.0 - (norm_error / norm_reference);
358+
const auto fit_change = fit - prev_fit;
389359
prev_fit = fit;
390360
// print fit data if required
391361
if (verbose) {
392-
std::cout << fit << "\t" << fit_change << std::endl;
362+
std::cout << MTtKRP.world().rank() << ": fit=" << fit
363+
<< " fit_change=" << fit_change << std::endl;
393364
}
394365

395366
// if the change in fit is less than the tolerance try to return true.
396-
if (fit_change < fit_tol) {
367+
if (abs(fit_change) < fit_tol) {
397368
converged_num++;
398369
if (converged_num == 2) {
399370
converged_num = 0;
400371
final_fit = prev_fit;
401372
prev_fit = 1.0;
373+
if (verbose)
374+
std::cout << MTtKRP.world().rank() << ": converged" << std::endl;
402375
return true;
376+
} else {
377+
TA_ASSERT(converged_num == 1);
378+
if (verbose)
379+
std::cout << MTtKRP.world().rank() << ": pre-converged" << std::endl;
403380
}
381+
} else {
382+
converged_num = 0;
404383
}
405384
return false;
406385
}

src/TiledArray/cp/cp_als.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
* takes a reference order-N tensor and decomposes it into a
1313
* set of order-2 tensors all coupled by a hyperdimension called the rank.
1414
* These factors are optimized using an alternating least squares
15-
* algorithm. This class is derived form the base CP class
15+
* algorithm.
1616
*
1717
* @tparam Tile typing for the DistArray tiles
1818
* @tparam Policy policy of the DistArray

src/TiledArray/dist_array.h

+18-17
Original file line numberDiff line numberDiff line change
@@ -890,20 +890,16 @@ class DistArray : public madness::archive::ParallelSerializableObject {
890890
///
891891
/// \tparam T The type of random value to generate. Defaults to
892892
/// element_type.
893-
/// \tparam <anonymous> A template type parameter which will be deduced as
894-
/// void only if MakeRandom knows how to generate random
895-
/// values of type T. If MakeRandom does not know how to
896-
/// generate random values of type T, SFINAE will disable
897-
/// this function.
898893
/// \param[in] skip_set If false, will throw if any tiles are already set
899894
/// \throw TiledArray::Exception if the PIMPL is not initialized. Strong
900895
/// throw guarantee.
901896
/// \throw TiledArray::Exception if skip_set is false and a local tile is
902897
/// already initialized. Weak throw guarantee.
903-
template <typename T = element_type,
898+
template <HostExecutor Exec = HostExecutor::Default,
899+
typename T = element_type,
904900
typename = detail::enable_if_can_make_random_t<T>>
905901
void fill_random(bool skip_set = false) {
906-
init_elements(
902+
init_elements<Exec>(
907903
[](const auto&) { return detail::MakeRandom<T>::generate_value(); });
908904
}
909905

@@ -943,7 +939,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
943939
/// guarantee.
944940
/// \throw TiledArray::Exception if a tile is already set and skip_set is
945941
/// false. Weak throw guarantee.
946-
template <typename Op>
942+
template <HostExecutor Exec = HostExecutor::Default, typename Op>
947943
void init_tiles(Op&& op, bool skip_set = false) {
948944
// lifetime management of op depends on whether it is a lvalue ref (i.e. has
949945
// an external owner) or an rvalue ref
@@ -957,15 +953,20 @@ class DistArray : public madness::archive::ParallelSerializableObject {
957953
const auto& index = *it;
958954
if (!pimpl_->is_zero(index)) {
959955
if (skip_set) {
960-
auto fut = find(index);
956+
auto fut = find_local(index);
961957
if (fut.probe()) continue;
962958
}
963-
Future<value_type> tile = pimpl_->world().taskq.add(
964-
[pimpl = pimpl_, index = ordinal_type(index),
965-
op_shared_handle]() -> value_type {
966-
return op_shared_handle(pimpl->trange().make_tile_range(index));
967-
});
968-
set(index, std::move(tile));
959+
if constexpr (Exec == HostExecutor::MADWorld) {
960+
Future<value_type> tile = pimpl_->world().taskq.add(
961+
[pimpl = pimpl_, index = ordinal_type(index),
962+
op_shared_handle]() -> value_type {
963+
return op_shared_handle(pimpl->trange().make_tile_range(index));
964+
});
965+
set(index, std::move(tile));
966+
} else {
967+
static_assert(Exec == HostExecutor::Thread);
968+
set(index, op_shared_handle(trange().make_tile_range(index)));
969+
}
969970
}
970971
}
971972
}
@@ -994,10 +995,10 @@ class DistArray : public madness::archive::ParallelSerializableObject {
994995
/// \throw TiledArray::Exception if skip_set is false and a local, non-zero
995996
/// tile is already initialized. Weak throw
996997
/// guarantee.
997-
template <typename Op>
998+
template <HostExecutor Exec = HostExecutor::Default, typename Op>
998999
void init_elements(Op&& op, bool skip_set = false) {
9991000
auto op_shared_handle = make_op_shared_handle(std::forward<Op>(op));
1000-
init_tiles(
1001+
init_tiles<Exec>(
10011002
[op = std::move(op_shared_handle)](
10021003
const TiledArray::Range& range) -> value_type {
10031004
// Initialize the tile with the given range object

src/TiledArray/einsum/tiledarray.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
225225
if (!term.array.is_local(idx)) continue;
226226
if (term.array.is_zero(idx)) continue;
227227
// TODO no need for immediate evaluation
228-
auto tile = term.array.find(idx).get();
228+
auto tile = term.array.find_local(idx).get();
229229
if (P) tile = tile.permute(P);
230230
auto shape = term.ei_tiled_range.tile(ei);
231231
tile = tile.reshape(shape, batch);
@@ -247,7 +247,7 @@ auto einsum(expressions::TsrExpr<Array_> A, expressions::TsrExpr<Array_> B,
247247
if (!C.ei.is_local(e)) continue;
248248
if (C.ei.is_zero(e)) continue;
249249
// TODO no need for immediate evaluation
250-
auto tile = C.ei.find(e).get();
250+
auto tile = C.ei.find_local(e).get();
251251
assert(tile.batch_size() == batch);
252252
const Permutation &P = C.permutation;
253253
auto c = apply(P, h + e);

src/TiledArray/fwd.h

+2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ using Array
149149
[[deprecated("use TiledArray::DistArray or TiledArray::TArray<T>")]] =
150150
DistArray<Tile, Policy>;
151151

152+
enum class HostExecutor { Thread, MADWorld, Default = MADWorld };
153+
152154
} // namespace TiledArray
153155

154156
#ifndef TILEDARRAY_DISABLE_NAMESPACE_TA

src/TiledArray/tensor.h

+14-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,20 @@ template <typename T, typename std::enable_if<detail::is_tensor<T>::value &&
6262
inline std::ostream& operator<<(std::ostream& os, const T& t) {
6363
os << t.range() << " { ";
6464
const auto n = t.range().volume();
65-
for (auto ord = 0ul; ord < n; ++ord) os << t.at_ordinal(ord) << " ";
66-
65+
std::size_t offset = 0ul;
66+
const auto more_than_1_batch = t.batch_size() > 1;
67+
for (auto b = 0ul; b != t.batch_size(); ++b) {
68+
if (more_than_1_batch) {
69+
os << "[batch " << b << "]{ ";
70+
}
71+
for (auto ord = 0ul; ord < n; ++ord) {
72+
os << t.data()[offset + ord] << " ";
73+
}
74+
if (more_than_1_batch) {
75+
os << "} ";
76+
}
77+
offset += n;
78+
}
6779
os << "}";
6880

6981
return os;

0 commit comments

Comments
 (0)