Skip to content

Commit ec58334

Browse files
committed
disable folding for sparse * dense contraction with a Hadamard index, which corrects behavior for that case. Broaden the amount of information given by print_map()
1 parent dfa4306 commit ec58334

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/contraction/contraction.cxx

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,22 @@ namespace CTF_int {
409409
if ((A->order+B->order+C->order)%2 == 1 ||
410410
(A->order+B->order+C->order)/2 < nfold ){
411411
return 0;
412+
} else {
413+
// do not allow weigh indices for sparse contractions
414+
int num_tot;
415+
int * idx_arr;
416+
417+
inv_idx(A->order, idx_A,
418+
B->order, idx_B,
419+
C->order, idx_C,
420+
&num_tot, &idx_arr);
421+
for (i=0; i<num_tot; i++){
422+
if (idx_arr[3*i] != -1 && idx_arr[3*i+1] != -1 && idx_arr[3*i+2] != -1){
423+
return 0;
424+
}
425+
}
426+
CTF_int::cdealloc(idx_arr);
412427
}
413-
//FIXME:!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
414428
}
415429
CTF_int::cdealloc(fold_idx);
416430
/* FIXME: 1 folded index is good enough for now, in the future model */

src/tensor/untyped_tensor.cxx

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -644,17 +644,17 @@ namespace CTF_int {
644644

645645
void tensor::print_map(FILE * stream, bool allcall) const {
646646
if (!allcall || wrld->rank == 0){
647-
/* if (is_sparse)
647+
if (is_sparse)
648648
printf("printing mapping of sparse tensor %s\n",name);
649649
else
650-
printf("printing mapping of dense tensor %s\n",name);*/
651-
/* if (topo != NULL){
650+
printf("printing mapping of dense tensor %s\n",name);
651+
if (topo != NULL){
652652
printf("CTF: %s mapped to order %d topology with dims:",name,topo->order);
653653
for (int dim=0; dim<topo->order; dim++){
654654
printf(" %d ",topo->lens[dim]);
655655
}
656656
}
657-
printf("\n");*/
657+
printf("\n");
658658
char tname[200];
659659
tname[0] = '\0';
660660
sprintf(tname, "%s[", name);
@@ -676,13 +676,13 @@ namespace CTF_int {
676676
// sprintf(tname+strlen(tname),"c%d",edge_map[dim].has_child);
677677
}
678678
sprintf(tname+strlen(tname), "]");
679-
printf("CTF: Tensor mapping is %s\n",tname);
680-
/* printf("\nCTF: sym len tphs pphs vphs\n");
679+
/*printf("CTF: Tensor mapping is %s\n",tname);
680+
printf("\nCTF: sym len tphs pphs vphs\n");
681681
for (int dim=0; dim<order; dim++){
682682
int tp = edge_map[dim].calc_phase();
683683
int pp = edge_map[dim].calc_phys_phase();
684684
int vp = tp/pp;
685-
printf("CTF: %2s %5d %5d %5d %5d\n", SY_strings[sym[dim]], lens[dim], tp, pp, vp);
685+
printf("CTF: %5d %5d %5d %5d\n", lens[dim], tp, pp, vp);
686686
}*/
687687
}
688688
}
@@ -1818,7 +1818,9 @@ namespace CTF_int {
18181818

18191819
if (wrld->rank == 0)
18201820
printf("Printing tensor %s\n",name);
1821-
//print_map(fp);
1821+
#ifdef DEBUG
1822+
print_map(fp);
1823+
#endif
18221824

18231825
/*for (int i=0; i<this->size; i++){
18241826
printf("this->data[%d] = ",i);

0 commit comments

Comments
 (0)