Skip to content

Commit 4d62853

Browse files
committed
adding complex conjugate
1 parent a33ee3a commit 4d62853

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

distributed/worker.cxx

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,9 @@ int executeProduct_(World& dw, std::map<std::string, std::unique_ptr<Tensor<>>>&
744744
(*(tensorC[uuid_D]))[idx_D] = (*(tensorC[uuid_C]))[idx_C];
745745
}
746746
else if(op_C==1){
747-
(*(tensorC[uuid_D]))[idx_D] = (conj_(*(tensorC[uuid_C])))[idx_C];
747+
CTF::Tensor<std::complex<double>> conjT(*(tensorC[uuid_C]));
748+
conjT[idx_C] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_C]))[idx_C]);
749+
(*(tensorC[uuid_D]))[idx_D] = conjT[idx_C];
748750
}
749751

750752
//B["i"] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_A]))["i"]);
@@ -753,17 +755,27 @@ int executeProduct_(World& dw, std::map<std::string, std::unique_ptr<Tensor<>>>&
753755
tensorC[uuid_D]->contract(alpha, *(tensorC[uuid_A]), idx_A, *(tensorC[uuid_B]), idx_B, beta, idx_D);
754756
}
755757
else if(op_A==1 && op_B==0){
756-
tensorC[uuid_D]->contract(alpha, conj_(*(tensorC[uuid_A])), idx_A, *(tensorC[uuid_B]), idx_B, beta, idx_D);
758+
CTF::Tensor<std::complex<double>> conjT(*(tensorC[uuid_A]));
759+
conjT[idx_A] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_A]))[idx_A]);
760+
tensorC[uuid_D]->contract(alpha, conjT, idx_A, *(tensorC[uuid_B]), idx_B, beta, idx_D);
757761
}
758762
else if(op_A==0 && op_B==1){
759-
tensorC[uuid_D]->contract(alpha, *(tensorC[uuid_A]), idx_A, conj_(*(tensorC[uuid_B])), idx_B, beta, idx_D);
763+
CTF::Tensor<std::complex<double>> conjT(*(tensorC[uuid_B]));
764+
conjT[idx_B] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_B]))[idx_B]);
765+
tensorC[uuid_D]->contract(alpha, *(tensorC[uuid_A]), idx_A, conjT, idx_B, beta, idx_D);
760766
}
761767
else if(op_A==1 && op_B==1){
762-
tensorC[uuid_D]->contract(alpha, conj_(*(tensorC[uuid_A])), idx_A, conj_(*(tensorC[uuid_B])), idx_B, beta, idx_D);
768+
CTF::Tensor<std::complex<double>> conjT(*(tensorC[uuid_A]));
769+
conjT[idx_A] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_A]))[idx_A]);
770+
CTF::Tensor<std::complex<double>> conjT2(*(tensorC[uuid_B]));
771+
conjT2[idx_B] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_B]))[idx_B]);
772+
tensorC[uuid_D]->contract(alpha, conjT, idx_A, conjT2, idx_B, beta, idx_D);
763773
}
764774

765775
if(op_D==1){
766-
(*(tensorC[uuid_D]))[idx_D] = (conj_(*(tensorC[uuid_D])))[idx_D];
776+
CTF::Tensor<std::complex<double>> conjT(*(tensorC[uuid_D]));
777+
conjT[idx_D] = CTF::Function<std::complex<double>>([](std::complex<double> a){ return std::conj(a); })((*(tensorC[uuid_D]))[idx_D]);
778+
(*(tensorC[uuid_D]))[idx_D] = conjT[idx_D];
767779
}
768780

769781

0 commit comments

Comments
 (0)