Skip to content

Commit f6ec0db

Browse files
committed
TA::einsum replicated tests
1 parent 0a92227 commit f6ec0db

File tree

2 files changed

+99
-79
lines changed

2 files changed

+99
-79
lines changed

tests/CMakeLists.txt

+63-79
Original file line numberDiff line numberDiff line change
@@ -33,86 +33,70 @@ set(executable ta_test)
3333
# N.B. 2: if you want to trim this down you may need to resolve linker errors due to missing fixture deps manually
3434
set(ta_test_src_files ta_test.cpp
3535
range.cpp
36-
btas.cpp
37-
meta.cpp
38-
type_traits.cpp
39-
utility.cpp
40-
permutation.cpp
41-
symm_permutation_group.cpp
42-
symm_irrep.cpp
43-
symm_representation.cpp
44-
block_range.cpp
45-
perm_index.cpp
46-
transform_iterator.cpp
47-
bitset.cpp
48-
math_outer.cpp
49-
math_partial_reduce.cpp
50-
math_transpose.cpp
51-
math_blas.cpp
52-
tensor.cpp
53-
tensor_of_tensor.cpp
54-
tensor_tensor_view.cpp
55-
tensor_shift_wrapper.cpp
56-
tiled_range1.cpp
57-
tiled_range.cpp
58-
blocked_pmap.cpp
59-
round_robin_pmap.cpp
60-
hash_pmap.cpp
61-
cyclic_pmap.cpp
62-
replicated_pmap.cpp
63-
dense_shape.cpp
64-
sparse_shape.cpp
65-
distributed_storage.cpp
66-
tensor_impl.cpp
67-
array_impl.cpp
68-
index_list.cpp
69-
bipartite_index_list.cpp
70-
dist_array.cpp
71-
conversions.cpp
72-
eigen.cpp
73-
dist_op_dist_cache.cpp
74-
dist_op_group.cpp
75-
dist_op_communicator.cpp
76-
tile_op_noop.cpp
77-
tile_op_scal.cpp
78-
dist_eval_array_eval.cpp
79-
dist_eval_unary_eval.cpp
80-
tile_op_add.cpp
81-
tile_op_scal_add.cpp
82-
tile_op_subt.cpp
83-
tile_op_scal_subt.cpp
84-
dist_eval_binary_eval.cpp
85-
tile_op_mult.cpp
86-
tile_op_scal_mult.cpp
87-
tile_op_contract_reduce.cpp
88-
reduce_task.cpp
89-
proc_grid.cpp
90-
dist_eval_contraction_eval.cpp
91-
expressions.cpp
92-
expressions_sparse.cpp
93-
expressions_complex.cpp
94-
expressions_btas.cpp
95-
expressions_mixed.cpp
96-
foreach.cpp
97-
solvers.cpp
98-
initializer_list.cpp
99-
diagonal_array.cpp
100-
retile.cpp
101-
tot_dist_array_part1.cpp
102-
tot_dist_array_part2.cpp
103-
random.cpp
104-
trace.cpp
105-
tot_expressions.cpp
106-
annotation.cpp
107-
diagonal_array.cpp
108-
contraction_helpers.cpp
109-
s_t_t_contract_.cpp
110-
t_t_t_contract_.cpp
111-
t_s_t_contract_.cpp
112-
# t_tot_tot_contract_.cpp
113-
# tot_tot_tot_contract_.cpp
36+
# tensor.cpp
37+
# tensor_of_tensor.cpp
38+
# tensor_tensor_view.cpp
39+
# tensor_shift_wrapper.cpp
40+
# tiled_range1.cpp
41+
# tiled_range.cpp
42+
# blocked_pmap.cpp
43+
# round_robin_pmap.cpp
44+
# hash_pmap.cpp
45+
# cyclic_pmap.cpp
46+
# replicated_pmap.cpp
47+
# dense_shape.cpp
48+
# sparse_shape.cpp
49+
# distributed_storage.cpp
50+
# tensor_impl.cpp
51+
# array_impl.cpp
52+
# index_list.cpp
53+
# bipartite_index_list.cpp
54+
# dist_array.cpp
55+
# conversions.cpp
56+
# eigen.cpp
57+
# dist_op_dist_cache.cpp
58+
# dist_op_group.cpp
59+
# dist_op_communicator.cpp
60+
# tile_op_noop.cpp
61+
# tile_op_scal.cpp
62+
# dist_eval_array_eval.cpp
63+
# dist_eval_unary_eval.cpp
64+
# tile_op_add.cpp
65+
# tile_op_scal_add.cpp
66+
# tile_op_subt.cpp
67+
# tile_op_scal_subt.cpp
68+
# dist_eval_binary_eval.cpp
69+
# tile_op_mult.cpp
70+
# tile_op_scal_mult.cpp
71+
# tile_op_contract_reduce.cpp
72+
# reduce_task.cpp
73+
# proc_grid.cpp
74+
# dist_eval_contraction_eval.cpp
75+
# expressions.cpp
76+
# expressions_sparse.cpp
77+
# expressions_complex.cpp
78+
# expressions_btas.cpp
79+
# expressions_mixed.cpp
80+
# foreach.cpp
81+
# solvers.cpp
82+
# initializer_list.cpp
83+
# diagonal_array.cpp
84+
# retile.cpp
85+
# tot_dist_array_part1.cpp
86+
# tot_dist_array_part2.cpp
87+
# random.cpp
88+
# trace.cpp
89+
# tot_expressions.cpp
90+
# annotation.cpp
91+
# diagonal_array.cpp
92+
# contraction_helpers.cpp
93+
# s_t_t_contract_.cpp
94+
# t_t_t_contract_.cpp
95+
# t_s_t_contract_.cpp
96+
# # t_tot_tot_contract_.cpp
97+
# # tot_tot_tot_contract_.cpp
11498
einsum.cpp
115-
linalg.cpp
99+
# linalg.cpp
116100
)
117101

118102
if(CUDA_FOUND)

tests/einsum.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,14 @@ auto random(Args ... args) {
334334
return t;
335335
}
336336

337+
template<typename T, typename P>
338+
auto replicated(DistArray<T,P> &&array) {
339+
array.world().gop.fence();
340+
array.make_replicated();
341+
array.world().gop.fence();
342+
return array;
343+
}
344+
337345
template<typename T, int NA, int NB, size_t NI, size_t NC>
338346
void einsum_eigen_contract_check(
339347
Eigen::Tensor<T,NA> A,
@@ -695,4 +703,32 @@ BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) {
695703
);
696704
}
697705

706+
BOOST_AUTO_TEST_CASE(einsum_tiledarray_replicated) {
707+
einsum_tiledarray_check<3,3,3>(
708+
replicated(random<DensePolicy>(7,14,3)),
709+
random<DensePolicy>(7,15,3),
710+
"hai,hbi->hab"
711+
);
712+
einsum_tiledarray_check<3,3,3>(
713+
random<DensePolicy>(7,14,3),
714+
replicated(random<DensePolicy>(7,15,3)),
715+
"hai,hbi->hab"
716+
);
717+
einsum_tiledarray_check<3,3,3>(
718+
replicated(random<DensePolicy>(7,14,3)),
719+
replicated(random<DensePolicy>(7,15,3)),
720+
"hai,hbi->hab"
721+
);
722+
einsum_tiledarray_check<2,2,1>(
723+
replicated(random<SparsePolicy>(7,14)),
724+
random<SparsePolicy>(7,14),
725+
"hi,hi->h"
726+
);
727+
einsum_tiledarray_check<2,2,1>(
728+
replicated(random<SparsePolicy>(7,14)),
729+
replicated(random<SparsePolicy>(7,14)),
730+
"hi,hi->h"
731+
);
732+
}
733+
698734
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)