Skip to content

Commit

Permalink
allocate wave-functions from memory pool (electronic-structure#784)
Browse files (browse the repository at this point in the history)
  • Loading branch information
toxa81 authored and mtaillefumier committed Nov 25, 2022
1 parent 5c51e6b commit 210ecf4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/SDDK/wave_functions.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -397,8 +397,8 @@ class Wave_functions_base
num_sc_ = num_spins(2);
}
for (int is = 0; is < num_sc_.get(); is++) {
data_[is] = sddk::mdarray<std::complex<T>, 2>(num_pw_ + num_mt_, num_wf_.get(), default_mem__,
"Wave_functions_base::data_");
data_[is] = sddk::mdarray<std::complex<T>, 2>(num_pw_ + num_mt_, num_wf_.get(),
sddk::get_memory_pool(default_mem__), "Wave_functions_base::data_");
}
}

Expand Down Expand Up @@ -493,7 +493,7 @@ class Wave_functions_base
allocate(sddk::memory_t mem__)
{
for (int s = 0; s < num_sc_.get(); s++) {
data_[s].allocate(get_memory_pool(mem__));
data_[s].allocate(sddk::get_memory_pool(mem__));
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/band/davidson.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -138,6 +138,7 @@ davidson(Hamiltonian_k<T>& Hk__, wf::num_bands num_bands__, wf::num_mag_dims num
{
PROFILE("sirius::davidson");

PROFILE_START("sirius::davidson|init");
auto& ctx = Hk__.H0().ctx();
ctx.print_memory_usage(__FILE__, __LINE__);

Expand Down Expand Up @@ -317,6 +318,7 @@ davidson(Hamiltonian_k<T>& Hk__, wf::num_bands num_bands__, wf::num_mag_dims num
<< " non-collinear : " << nc_mag << std::endl
<< " number of extra phi : " << num_extra_phi << std::endl;
}
PROFILE_STOP("sirius::davidson|init");

PROFILE_START("sirius::davidson|iter");
for (int ispin_step = 0; ispin_step < num_spinors; ispin_step++) {
Expand Down
21 changes: 10 additions & 11 deletions src/context/simulation_context.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -593,9 +593,9 @@ Simulation_context::initialize()

/* setup BLACS grid */
if (std_solver.is_parallel()) {
blacs_grid_ = std::unique_ptr<sddk::BLACS_grid>(new sddk::BLACS_grid(comm_band(), npr, npc));
blacs_grid_ = std::make_unique<sddk::BLACS_grid>(comm_band(), npr, npc);
} else {
blacs_grid_ = std::unique_ptr<sddk::BLACS_grid>(new sddk::BLACS_grid(sddk::Communicator::self(), 1, 1));
blacs_grid_ = std::make_unique<sddk::BLACS_grid>(sddk::Communicator::self(), 1, 1);
}

/* setup the cyclic block size */
Expand Down Expand Up @@ -1002,13 +1002,13 @@ Simulation_context::update()
auto spl_z = split_fft_z(fft_coarse_grid_[2], comm_fft_coarse());

/* create spfft buffer for coarse transform */
spfft_grid_coarse_ = std::unique_ptr<spfft::Grid>(new spfft::Grid(
fft_coarse_grid_[0], fft_coarse_grid_[1], fft_coarse_grid_[2], gvec_coarse_fft_->zcol_count_fft(),
spl_z.local_size(), spfft_pu, -1, comm_fft_coarse().mpi_comm(), SPFFT_EXCH_DEFAULT));
spfft_grid_coarse_ = std::make_unique<spfft::Grid>(fft_coarse_grid_[0], fft_coarse_grid_[1],
fft_coarse_grid_[2], gvec_coarse_fft_->zcol_count_fft(),
spl_z.local_size(), spfft_pu, -1, comm_fft_coarse().mpi_comm(), SPFFT_EXCH_DEFAULT);
#ifdef USE_FP32
spfft_grid_coarse_float_ = std::unique_ptr<spfft::GridFloat>(new spfft::GridFloat(
fft_coarse_grid_[0], fft_coarse_grid_[1], fft_coarse_grid_[2], gvec_coarse_fft_->zcol_count_fft(),
spl_z.local_size(), spfft_pu, -1, comm_fft_coarse().mpi_comm(), SPFFT_EXCH_DEFAULT));
spfft_grid_coarse_float_ = std::make_unique<spfft::GridFloat>(fft_coarse_grid_[0], fft_coarse_grid_[1],
fft_coarse_grid_[2], gvec_coarse_fft_->zcol_count_fft(), spl_z.local_size(), spfft_pu, -1,
comm_fft_coarse().mpi_comm(), SPFFT_EXCH_DEFAULT);
#endif
/* create spfft transformations */
const auto fft_type_coarse = gvec_coarse().reduced() ? SPFFT_TRANS_R2C : SPFFT_TRANS_C2C;
Expand Down Expand Up @@ -1289,8 +1289,7 @@ Simulation_context::update()
}
for (int iat = 0; iat < unit_cell().num_atom_types(); iat++) {
if (unit_cell().atom_type(iat).augment() && unit_cell().atom_type(iat).num_atoms() > 0) {
augmentation_op_[iat] = std::unique_ptr<Augmentation_operator>(
new Augmentation_operator(unit_cell().atom_type(iat), gvec()));
augmentation_op_[iat] = std::make_unique<Augmentation_operator>(unit_cell().atom_type(iat), gvec());
augmentation_op_[iat]->generate_pw_coeffs(aug_ri(), gvec_tp_, *mp, mpd);
} else {
augmentation_op_[iat] = nullptr;
Expand Down Expand Up @@ -1589,7 +1588,7 @@ Simulation_context::init_comm()
/* here we know the number of ranks for band parallelization */

/* if we have multiple ranks per node and band parallelization, switch to parallel FFT for coarse mesh */
if (sddk::num_ranks_per_node() > 1 && comm_band().size() > 1) {
if ((npr == npb) || (sddk::num_ranks_per_node() > acc::num_devices() && comm_band().size() > 1)) {
cfg().control().fft_mode("parallel");
}

Expand Down

0 comments on commit 210ecf4

Please sign in to comment.