Skip to content

Commit

Permalink
MultiDiracDeterminant.2.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
anbenali committed Mar 28, 2022
1 parent ebc3332 commit 7818a7e
Showing 1 changed file with 76 additions and 67 deletions.
143 changes: 76 additions & 67 deletions src/QMCWaveFunctions/Fermion/MultiDiracDeterminant.2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,9 @@ void MultiDiracDeterminant::mw_evaluateDetsForPtclMove(const RefVectorWithLeader
det_leader.RatioTimer.start();

OffloadVector<ValueType> det0_list(nw, 1.0);

constexpr ValueType czero(0);
UnpinnedOffloadVector<ValueType> czero_vec(nw, czero);

ValueType* czero_ptr = czero_vec.data();
int dummy_handle = 0;

OffloadVector<ValueType> curRatio_list(nw, 0.0);
OffloadVector<size_t> confgListOccup(det_leader.NumPtcls,0.0);

RefVector<OffloadVector<ValueType>> workV1_list, workV2_list, workV3_list;
RefVector<OffloadVector<ValueType>> psiV_list, psiV_temp_list, new_ratios_to_ref_list;
RefVector<OffloadMatrix<ValueType>> TpsiM_list, psiM_list, dotProducts_list;
Expand Down Expand Up @@ -387,69 +382,92 @@ void MultiDiracDeterminant::mw_evaluateDetsForPtclMove(const RefVectorWithLeader
MultiDiracDeterminant& det = (det_list[iw]);
Vector<ValueType> psiV_list_host_view(psiV_list[iw].get().data(), psiV_list[iw].get().size());
det.getPhi()->evaluateValue(P_list[iw], iat, psiV_list_host_view);
/// Transfer psiV, just evaluated on the host, to the device.
psiV_list[iw].get().updateTo();
}


det_leader.evalOrbTimer.stop();

det_leader.ExtraStuffTimer.start();
for (size_t iw = 0; iw < nw; iw++)
{
MultiDiracDeterminant& det = (det_list[iw]);
const auto& confgList = *det.ciConfigList;
auto it(confgList[det.ReferenceDeterminant].occup.begin());
psiMinv_temp_list[iw].get() = psiMinv_list[iw].get();
for (size_t i = 0; i < det_leader.NumPtcls; i++)
psiV_temp_list[iw].get()[i] = psiV_list[iw].get()[*(it++)];
for (size_t i = 0; i < det_leader.NumOrbitals; i++)
TpsiM_list[iw].get()(i, WorkingIndex) = psiV_list[iw].get()[i];

/// Transfer psiV_temp, psiMinv_temp and TpsiM to the device.
psiV_temp_list[iw].get().updateTo();
psiMinv_temp_list[iw].get().updateTo();
TpsiM_list[iw].get().updateTo();
}

const ValueType cone(1);
size_t success = 0;
int dummy_handle=0;
const auto psiMinv_rows = psiMinv_list[0].get().rows();
const auto psiMinv_cols = psiMinv_list[0].get().cols();
const int xmax = psiMinv_cols * psiMinv_cols;
const auto TpsiM_cols = TpsiM_list[0].get().cols();
const auto psiM_cols = psiM_list[0].get().cols();
const auto TpsiM_rows = TpsiM_list[0].get().rows();
const auto NumPtcls = det_leader.NumPtcls;
const auto NumOrbitals = det_leader.NumOrbitals;
const auto& confgList = *det_leader.ciConfigList;
for (size_t i = 0; i < det_leader.NumPtcls; i++)
confgListOccup[i]=confgList[det_leader.ReferenceDeterminant].occup[i];

OffloadVector<ValueType*> psiV_deviceptr_list(nw);
OffloadVector<ValueType*> psiV_temp_deviceptr_list(nw);

OffloadVector<ValueType*> psiMinv_deviceptr_list(nw);
OffloadVector<ValueType*> psiMinv_temp_deviceptr_list(nw);
OffloadVector<ValueType*> workV1_deviceptr_list(nw);
OffloadVector<ValueType*> workV2_deviceptr_list(nw);

OffloadVector<ValueType*> psiV_temp_hostptr_list(nw);
OffloadVector<ValueType*> psiMinv_temp_hostptr_list(nw);
OffloadVector<ValueType*> TpsiM_deviceptr_list(nw);
OffloadVector<ValueType*> psiM_deviceptr_list(nw);

for (size_t iw = 0; iw < nw; iw++)
{
psiV_deviceptr_list[iw] = psiV_list[iw].get().device_data();
psiV_temp_deviceptr_list[iw] = psiV_temp_list[iw].get().device_data();
psiMinv_temp_deviceptr_list[iw] = psiMinv_temp_list[iw].get().device_data();
workV1_deviceptr_list[iw] = workV1_list[iw].get().data();
workV2_deviceptr_list[iw] = workV2_list[iw].get().data();

psiMinv_deviceptr_list[iw] = psiMinv_list[iw].get().device_data();
psiMinv_temp_deviceptr_list[iw] = psiMinv_temp_list[iw].get().device_data();

psiV_temp_hostptr_list[iw] = psiV_temp_list[iw].get().data();
psiMinv_temp_hostptr_list[iw] = psiMinv_temp_list[iw].get().data();
TpsiM_deviceptr_list[iw] = TpsiM_list[iw].get().device_data();
psiM_deviceptr_list[iw] = psiM_list[iw].get().device_data();
}

auto* psiV_temp_list_Hptr = psiV_temp_hostptr_list.data();
auto* psiMinv_temp_list_Hptr = psiMinv_temp_hostptr_list.data();


auto* psiV_list_ptr = psiV_deviceptr_list.data();
auto* psiV_temp_list_ptr = psiV_temp_deviceptr_list.data();
auto* psiMinv_temp_list_ptr = psiMinv_temp_deviceptr_list.data();
auto* workV1_list_ptr = workV1_deviceptr_list.data();
auto* workV2_list_ptr = workV2_deviceptr_list.data();

auto* psiMinv_list_ptr = psiMinv_deviceptr_list.device_data();
auto* psiMinv_temp_list_ptr = psiMinv_temp_deviceptr_list.device_data();

auto* TpsiM_list_ptr = TpsiM_deviceptr_list.device_data();
auto* psiM_list_ptr = psiM_deviceptr_list.data();

auto* curRatio_list_ptr = curRatio_list.data();
auto* confgListOccup_ptr = confgListOccup.data();

det_leader.ExtraStuffTimer.start();

psiMinv_deviceptr_list.updateTo();
psiMinv_temp_deviceptr_list.updateTo();

confgListOccup.updateTo();
det0_list.updateTo();

success=ompBLAS::copy_batched(dummy_handle, psiMinv_rows*psiMinv_cols, psiMinv_list_ptr,1,psiMinv_temp_list_ptr, 1, nw) ;
if (success != 0)
throw std::runtime_error("In MultiDiracDeterminant ompBLAS::copy_batched_offset failed.");

TpsiM_deviceptr_list.updateTo();
psiV_deviceptr_list.updateTo();
success=ompBLAS::copy_batched_offset(dummy_handle, det_leader.NumOrbitals, psiV_list_ptr, 0, 1, TpsiM_list_ptr, WorkingIndex, TpsiM_cols, nw);
if (success != 0)
throw std::runtime_error("In MultiDiracDeterminant ompBLAS::copy_batched_offset failed.");



PRAGMA_OFFLOAD("omp target teams distribute map(always,from:curRatio_list_ptr[:nw]) \
map(always, to: psiV_temp_list_ptr[:nw], psiMinv_temp_list_ptr[:nw])")
map(always, to: psiV_temp_list_ptr[:nw]) \
is_device_ptr(psiV_list_ptr) \
is_device_ptr(psiMinv_temp_list_ptr)")
for (size_t iw = 0; iw < nw; iw++)
{
for (size_t i = 0; i < NumPtcls; i++)
{
const size_t J=confgListOccup_ptr[i];
psiV_temp_list_ptr[iw][i] = psiV_list_ptr[iw][J];
}

ValueType c_ratio = 0.0;
PRAGMA_OFFLOAD("omp parallel for reduction(+ : c_ratio)")
for (size_t jc = 0; jc < psiMinv_cols; jc += 1)
Expand All @@ -471,13 +489,18 @@ void MultiDiracDeterminant::mw_evaluateDetsForPtclMove(const RefVectorWithLeader
TpsiM_list, *det_leader.detData, *det_leader.uniquePairs,
*det_leader.DetSigns, dotProducts_list, new_ratios_to_ref_list);

PRAGMA_OFFLOAD("omp target teams distribute parallel for collapse(2) is_device_ptr(TpsiM_list_ptr) \
map(always, to:psiM_list_ptr[:nw])")
for (size_t iw = 0; iw < nw; iw++)
for (size_t i = 0; i < NumOrbitals; i++)
TpsiM_list_ptr[iw][i * TpsiM_cols + WorkingIndex] = psiM_list_ptr[iw][i * psiM_cols + WorkingIndex];

det_leader.ExtraStuffTimer.stop();
for (size_t iw = 0; iw < nw; iw++)
{
TpsiM_list[iw].get().updateFrom();
MultiDiracDeterminant& det = (det_list[iw]);
det.curRatio = curRatio_list_ptr[iw];
for (size_t i = 0; i < det_leader.NumOrbitals; i++)
TpsiM_list[iw].get()(i, WorkingIndex) = psiM_list[iw].get()(WorkingIndex, i);
}

det_leader.RatioTimer.stop();
Expand Down Expand Up @@ -893,7 +916,6 @@ void MultiDiracDeterminant::mw_evaluateDetsAndGradsForPtclMove(

for (size_t iw = 0; iw < nw; iw++)
{
/// @YE_LUO: Not sure whether this transfer is needed.
TpsiM_list[iw].get().updateFrom();
MultiDiracDeterminant& det = (det_list[iw]);
det.curRatio = curRatio_list[iw];
Expand Down Expand Up @@ -1114,21 +1136,11 @@ void MultiDiracDeterminant::mw_evaluateGrads(const RefVectorWithLeader<MultiDira
dpsiMinv_list);


/* PRAGMA_OFFLOAD("omp target teams distribute parallel for collapse(2) map(to:dpsiM_list_ptr[:nw]) \
PRAGMA_OFFLOAD("omp target teams distribute parallel for collapse(2) map(to:dpsiM_list_ptr[:nw]) \
map(always,to: TpsiM_list_ptr[:nw])")
for (size_t iw = 0; iw < nw; iw++)
for (size_t i = 0; i < NumOrbitals; i++)
TpsiM_list_ptr[iw][i*TpsiM_cols+ WorkingIndex] = dpsiM_list_ptr[iw][dpsiM_cols*WorkingIndex+i][idim];
*/

for (size_t iw = 0; iw < nw; iw++)
{
for (size_t i = 0; i < det_leader.NumOrbitals; i++)
{
TpsiM_list[iw].get()(i, WorkingIndex) = dpsiM_list[iw].get()(WorkingIndex, i)[idim];
}
TpsiM_list[iw].get().updateTo();
}

det_leader.mw_BuildDotProductsAndCalculateRatiosGrads(nw, det_leader.ReferenceDeterminant, WorkingIndex, idim,
det_leader.getNumDets(), ratioG_list, dpsiMinv_list,
Expand All @@ -1137,18 +1149,15 @@ void MultiDiracDeterminant::mw_evaluateGrads(const RefVectorWithLeader<MultiDira
grads_list, Grads);


// PRAGMA_OFFLOAD("omp target teams distribute parallel for map(from:TpsiM_list_ptr[:nw]) \
// map(always,to:psiM_list_ptr[:nw])")
// for (size_t iw = 0; iw < nw; iw++)
// for (size_t i = 0; i < NumOrbitals; i++)
// TpsiM_list_ptr[iw][i*TpsiM_cols+ WorkingIndex] = psiM_list_ptr[iw][i+psiM_cols*WorkingIndex];
PRAGMA_OFFLOAD("omp target teams distribute parallel for map(from:TpsiM_list_ptr[:nw]) \
map(always,to:psiM_list_ptr[:nw])")
for (size_t iw = 0; iw < nw; iw++)
for (size_t i = 0; i < NumOrbitals; i++)
TpsiM_list_ptr[iw][i*TpsiM_cols+ WorkingIndex] = psiM_list_ptr[iw][i+psiM_cols*WorkingIndex];

for (size_t iw = 0; iw < nw; iw++)
for (size_t i = 0; i < det_leader.NumOrbitals; i++)
TpsiM_list[iw].get()(i, WorkingIndex) = psiM_list[iw].get()(WorkingIndex, i);
}
//for (size_t iw = 0; iw < nw; iw++)
// TpsiM_list[iw].get().updateFrom();
for (size_t iw = 0; iw < nw; iw++)
TpsiM_list[iw].get().updateFrom();
}

void MultiDiracDeterminant::mw_updateRatios_generic(int ext_level,
Expand Down

0 comments on commit 7818a7e

Please sign in to comment.