Skip to content

Commit fa830a1

Browse files
committed
expression suite: fixups for contractions w/ zero-volume contraction range and w/ nonzero-volume contraction range producing zero-volume result
1 parent e1883fe commit fa830a1

File tree

4 files changed

+146
-83
lines changed

4 files changed

+146
-83
lines changed

src/TiledArray/dist_eval/contraction_eval.h

Lines changed: 112 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -889,45 +889,63 @@ class Summa
889889

890890
/// Initialize reduce tasks and construct broadcast groups
891891
ordinal_type initialize(const DenseShape&) {
892-
// Construct static broadcast groups for dense arguments
893-
const madness::DistributedID col_did(DistEvalImpl_::id(), 0ul);
894-
if (k_ > 0) col_group_ = proc_grid_.make_col_group(col_did);
895-
const madness::DistributedID row_did(DistEvalImpl_::id(), k_);
896-
if (k_ > 0) row_group_ = proc_grid_.make_row_group(row_did);
892+
// if contraction is over zero-volume range just initialize tiles to zero
893+
if (k_ == 0) {
894+
ordinal_type tile_count = 0;
895+
const auto& tiles_range = this->trange().tiles_range();
896+
for (auto&& tile_idx : tiles_range) {
897+
auto tile_ord = tiles_range.ordinal(tile_idx);
898+
if (this->is_local(tile_ord)) {
899+
this->world().taskq.add([this, tile_ord, tile_idx]() {
900+
this->set_tile(tile_ord,
901+
value_type(this->trange().tile(tile_idx),
902+
typename value_type::value_type{}));
903+
});
904+
++tile_count;
905+
}
906+
}
907+
return tile_count;
908+
} else {
909+
// Construct static broadcast groups for dense arguments
910+
const madness::DistributedID col_did(DistEvalImpl_::id(), 0ul);
911+
col_group_ = proc_grid_.make_col_group(col_did);
912+
const madness::DistributedID row_did(DistEvalImpl_::id(), k_);
913+
row_group_ = proc_grid_.make_row_group(row_did);
897914

898915
#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
899-
std::stringstream ss;
900-
ss << "init: rank=" << TensorImpl_::world().rank() << "\n col_group_=("
901-
<< col_did.first << ", " << col_did.second << ") { ";
902-
for (ProcessID gproc = 0ul; gproc < col_group_.size(); ++gproc)
903-
ss << col_group_.world_rank(gproc) << " ";
904-
ss << "}\n row_group_=(" << row_did.first << ", " << row_did.second
905-
<< ") { ";
906-
for (ProcessID gproc = 0ul; gproc < row_group_.size(); ++gproc)
907-
ss << row_group_.world_rank(gproc) << " ";
908-
ss << "}\n";
909-
printf(ss.str().c_str());
916+
std::stringstream ss;
917+
ss << "init: rank=" << TensorImpl_::world().rank() << "\n col_group_=("
918+
<< col_did.first << ", " << col_did.second << ") { ";
919+
for (ProcessID gproc = 0ul; gproc < col_group_.size(); ++gproc)
920+
ss << col_group_.world_rank(gproc) << " ";
921+
ss << "}\n row_group_=(" << row_did.first << ", " << row_did.second
922+
<< ") { ";
923+
for (ProcessID gproc = 0ul; gproc < row_group_.size(); ++gproc)
924+
ss << row_group_.world_rank(gproc) << " ";
925+
ss << "}\n";
926+
printf(ss.str().c_str());
910927
#endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
911928

912-
// Allocate memory for the reduce pair tasks.
913-
std::allocator<ReducePairTask<op_type>> alloc;
914-
reduce_tasks_ = alloc.allocate(proc_grid_.local_size());
929+
// Allocate memory for the reduce pair tasks.
930+
std::allocator<ReducePairTask<op_type>> alloc;
931+
reduce_tasks_ = alloc.allocate(proc_grid_.local_size());
915932

916-
// Iterate over all local tiles
917-
const ordinal_type n = proc_grid_.local_size();
918-
for (ordinal_type t = 0ul; t < n; ++t) {
919-
// Initialize the reduction task
920-
ReducePairTask<op_type>* MADNESS_RESTRICT const reduce_task =
921-
reduce_tasks_ + t;
922-
new (reduce_task) ReducePairTask<op_type>(TensorImpl_::world(), op_
933+
// Iterate over all local tiles
934+
const ordinal_type n = proc_grid_.local_size();
935+
for (ordinal_type t = 0ul; t < n; ++t) {
936+
// Initialize the reduction task
937+
ReducePairTask<op_type>* MADNESS_RESTRICT const reduce_task =
938+
reduce_tasks_ + t;
939+
new (reduce_task) ReducePairTask<op_type>(TensorImpl_::world(), op_
923940
#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
924-
,
925-
nullptr, t
941+
,
942+
nullptr, t
926943
#endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
927-
);
928-
}
944+
);
945+
}
929946

930-
return proc_grid_.local_size();
947+
return proc_grid_.local_size();
948+
}
931949
}
932950

933951
/// Initialize reduce tasks
@@ -938,6 +956,9 @@ class Summa
938956
ss << " initialize rank=" << TensorImpl_::world().rank() << " tiles={ ";
939957
#endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
940958

959+
// fast return if there is no work to do
960+
if (k_ == 0) return 0;
961+
941962
// Allocate memory for the reduce pair tasks.
942963
std::allocator<ReducePairTask<op_type>> alloc;
943964
reduce_tasks_ = alloc.allocate(proc_grid_.local_size());
@@ -1705,60 +1726,79 @@ class Summa
17051726
std::max(ProcGrid::size_type(2),
17061727
std::min(proc_grid_.proc_rows(), proc_grid_.proc_cols()));
17071728

1708-
// corner case: empty result
1709-
if (k_ == 0) return 0;
1710-
1711-
// Construct the first SUMMA iteration task
1712-
if (TensorImpl_::shape().is_dense()) {
1713-
// We cannot have more iterations than there are blocks in the k
1714-
// dimension
1715-
if (depth > k_) depth = k_;
1716-
1717-
// Modify the number of concurrent iterations based on the available
1718-
// memory.
1719-
depth = mem_bound_depth(depth, 0.0f, 0.0f);
1720-
1721-
// Enforce user defined depth bound
1722-
if (max_depth_) depth = std::min(depth, max_depth_);
1723-
1724-
TensorImpl_::world().taskq.add(
1725-
new DenseStepTask(shared_from_this(), depth));
1726-
} else {
1727-
// Increase the depth based on the amount of sparsity in an iteration.
1729+
// watch out for the corner case: contraction over zero-volume range
1730+
// producing nonzero-volume result ... in that case there is nothing to do
1731+
// the appropriate initialization was performed in the initialize() method
1732+
if (k_ != 0) {
1733+
// Construct the first SUMMA iteration task
1734+
if (TensorImpl_::shape().is_dense()) {
1735+
// We cannot have more iterations than there are blocks in the k
1736+
// dimension
1737+
if (depth > k_) depth = k_;
1738+
1739+
// Modify the number of concurrent iterations based on the available
1740+
// memory.
1741+
depth = mem_bound_depth(depth, 0.0f, 0.0f);
1742+
1743+
// Enforce user defined depth bound
1744+
if (max_depth_) depth = std::min(depth, max_depth_);
1745+
1746+
TensorImpl_::world().taskq.add(
1747+
new DenseStepTask(shared_from_this(), depth));
1748+
} else {
1749+
// Increase the depth based on the amount of sparsity in an iteration.
17281750

1729-
// Get the sparsity fractions for the left- and right-hand arguments.
1730-
const float left_sparsity = left_.shape().sparsity();
1731-
const float right_sparsity = right_.shape().sparsity();
1751+
// Get the sparsity fractions for the left- and right-hand arguments.
1752+
const float left_sparsity = left_.shape().sparsity();
1753+
const float right_sparsity = right_.shape().sparsity();
17321754

1733-
// Compute the fraction of non-zero result tiles in a single SUMMA
1734-
// iteration.
1735-
const float frac_non_zero = (1.0f - std::min(left_sparsity, 0.9f)) *
1736-
(1.0f - std::min(right_sparsity, 0.9f));
1755+
// Compute the fraction of non-zero result tiles in a single SUMMA
1756+
// iteration.
1757+
const float frac_non_zero = (1.0f - std::min(left_sparsity, 0.9f)) *
1758+
(1.0f - std::min(right_sparsity, 0.9f));
17371759

1738-
// Compute the new depth based on sparsity of the arguments
1739-
depth =
1740-
float(depth) * (1.0f - 1.35638f * std::log2(frac_non_zero)) + 0.5f;
1760+
// Compute the new depth based on sparsity of the arguments
1761+
depth = float(depth) * (1.0f - 1.35638f * std::log2(frac_non_zero)) +
1762+
0.5f;
17411763

1742-
// We cannot have more iterations than there are blocks in the k
1743-
// dimension
1744-
if (depth > k_) depth = k_;
1764+
// We cannot have more iterations than there are blocks in the k
1765+
// dimension
1766+
if (depth > k_) depth = k_;
17451767

1746-
// Modify the number of concurrent iterations based on the available
1747-
// memory and sparsity of the argument tensors.
1748-
depth = mem_bound_depth(depth, left_sparsity, right_sparsity);
1768+
// Modify the number of concurrent iterations based on the available
1769+
// memory and sparsity of the argument tensors.
1770+
depth = mem_bound_depth(depth, left_sparsity, right_sparsity);
17491771

1750-
// Enforce user defined depth bound
1751-
if (max_depth_) depth = std::min(depth, max_depth_);
1772+
// Enforce user defined depth bound
1773+
if (max_depth_) depth = std::min(depth, max_depth_);
17521774

1753-
TensorImpl_::world().taskq.add(
1754-
new SparseStepTask(shared_from_this(), depth));
1755-
}
1775+
TensorImpl_::world().taskq.add(
1776+
new SparseStepTask(shared_from_this(), depth));
1777+
}
1778+
} // k_ != 0
17561779
}
17571780

17581781
#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
17591782
printf("eval: start wait children rank=%i\n", TensorImpl_::world().rank());
17601783
#endif // TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL
17611784

1785+
// corner case: if left or right are zero-volume no tasks were scheduled, so
1786+
// need to discard all of their tiles manually
1787+
if (left_.range().volume() == 0) {
1788+
for (auto&& tile_idx : right_.range()) {
1789+
auto tile_ord = right_.range().ordinal(tile_idx);
1790+
if (right_.is_local(tile_ord) && !right_.is_zero(tile_ord))
1791+
right_.discard(tile_ord);
1792+
}
1793+
}
1794+
if (right_.range().volume() == 0) {
1795+
for (auto&& tile_idx : left_.range()) {
1796+
auto tile_ord = left_.range().ordinal(tile_idx);
1797+
if (left_.is_local(tile_ord) && !left_.is_zero(tile_ord))
1798+
left_.discard(tile_ord);
1799+
}
1800+
}
1801+
17621802
// Wait for child tensors to be evaluated, and process tasks while waiting.
17631803
left_.wait();
17641804
right_.wait();

src/TiledArray/expressions/cont_engine.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,26 @@ class ContEngine : public BinaryEngine<Derived> {
343343
n *= right_element_size[i];
344344
}
345345

346-
// Construct the process grid.
347-
proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n);
348-
349-
// Initialize children
350-
left_.init_distribution(world, proc_grid_.make_row_phase_pmap(K_));
351-
right_.init_distribution(world, proc_grid_.make_col_phase_pmap(K_));
352-
353-
// Initialize the process map if not already defined
354-
if (!pmap) pmap = proc_grid_.make_pmap();
355-
ExprEngine_::init_distribution(world, pmap);
346+
// corner case: zero-volume result ... easier to skip proc_grid_
347+
// construction alltogether
348+
if (M == 0 || N == 0) {
349+
left_.init_distribution(world, {});
350+
right_.init_distribution(world, {});
351+
ExprEngine_::init_distribution(
352+
world, (pmap ? pmap : policy::default_pmap(*world, M * N)));
353+
} else { // M!=0 && N!=0
354+
355+
// Construct the process grid.
356+
proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n);
357+
358+
// Initialize children
359+
left_.init_distribution(world, proc_grid_.make_row_phase_pmap(K_));
360+
right_.init_distribution(world, proc_grid_.make_col_phase_pmap(K_));
361+
362+
// Initialize the process map if not already defined
363+
if (!pmap) pmap = proc_grid_.make_pmap();
364+
ExprEngine_::init_distribution(world, pmap);
365+
}
356366
}
357367

358368
/// Tiled range factory function

tests/expressions_fixture.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ struct ExpressionsFixture : public TiledRangeFixture {
6363
s_tr1_2(make_random_sparseshape(trange1)),
6464
s_tr2(make_random_sparseshape(trange2)),
6565
s_trC(make_random_sparseshape(trangeC)),
66+
s_trC_f(make_random_sparseshape(trangeC_f)),
6667
a(*GlobalFixture::world, tr, s_tr_1),
6768
b(*GlobalFixture::world, tr, s_tr_2),
6869
c(*GlobalFixture::world, tr, s_tr_2),
6970
aC(*GlobalFixture::world, trangeC, s_trC),
71+
aC_f(*GlobalFixture::world, trangeC_f, s_trC_f),
7072
u(*GlobalFixture::world, trange1, s_tr1_1),
7173
v(*GlobalFixture::world, trange1, s_tr1_2),
7274
w(*GlobalFixture::world, trange2, s_tr2) {
@@ -92,12 +94,14 @@ struct ExpressionsFixture : public TiledRangeFixture {
9294
u(*GlobalFixture::world, trange1),
9395
v(*GlobalFixture::world, trange1),
9496
w(*GlobalFixture::world, trange2),
95-
aC(*GlobalFixture::world, trangeC) {
97+
aC(*GlobalFixture::world, trangeC),
98+
aC_f(*GlobalFixture::world, trangeC_f) {
9699
random_fill(a);
97100
random_fill(b);
98101
random_fill(u);
99102
random_fill(v);
100103
random_fill(aC);
104+
random_fill(aC_f);
101105
GlobalFixture::world->gop.fence();
102106
}
103107

@@ -221,19 +225,25 @@ struct ExpressionsFixture : public TiledRangeFixture {
221225
// contains empty trange1
222226
const TiledRange trangeC{TiledRange1{0, 2, 5, 10}, TiledRange1{},
223227
TiledRange1{0, 2, 7, 11}};
228+
// like trC, but with all dimension nonempty
229+
const TiledRange trangeC_f{trangeC.dim(0), TiledRange1{0, 4, 7},
230+
trangeC.dim(2)};
231+
224232
SparseShape<float> s_tr_1;
225233
SparseShape<float> s_tr_2;
226234
SparseShape<float> s_tr1_1;
227235
SparseShape<float> s_tr1_2;
228236
SparseShape<float> s_tr2;
229237
SparseShape<float> s_trC;
238+
SparseShape<float> s_trC_f;
230239
TArray a;
231240
TArray b;
232241
TArray c;
233242
TArray u;
234243
TArray v;
235244
TArray w;
236245
TArray aC;
246+
TArray aC_f;
237247
}; // ExpressionsFixture
238248

239249
#endif // TILEDARRAY_TEST_EXPRESSIONS_FIXTURE_H

tests/expressions_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2946,6 +2946,7 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(inner_product, F, Fixtures, F) {
29462946
BOOST_FIXTURE_TEST_CASE_TEMPLATE(empty_trange1, F, Fixtures, F) {
29472947
auto& c = F::c;
29482948
auto& aC = F::aC;
2949+
auto& aC_f = F::aC_f;
29492950

29502951
// unary/binary expressions
29512952
{
@@ -2981,6 +2982,8 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(empty_trange1, F, Fixtures, F) {
29812982
BOOST_CHECK_NO_THROW(t2("a,d") = aC("a,b,c") * aC("d,b,c"));
29822983
// contraction over nonempty dims
29832984
BOOST_CHECK_NO_THROW(t4("b,a,e,d") = aC("a,b,c") * aC("d,e,c"));
2985+
// contraction over nonempty dims, involving expressions with nonzero-volume
2986+
BOOST_CHECK_NO_THROW(t4("b,a,e,d") = aC("a,b,c") * (2. * aC_f("d,e,c")));
29842987
}
29852988

29862989
// reduction expressions

0 commit comments

Comments
 (0)