@@ -121,13 +121,15 @@ void generate_alltoall_update_data(
121
121
std::vector<RepartDistMatrix::all_to_all_data> &update_data)
122
122
{
123
123
label linop_offset_store{0 };
124
- for (size_t i = 0 ; i < 3 ; i++) {
124
+ // NOTE in case of symmetric matrix 0 (upper) is same as 1 (lower)
125
+ // thus we can start at 1
126
+ label start = 0 ;
127
+ for (size_t i = start; i < 3 ; i++) {
125
128
label interface_size = in->get_rows ()[i].size ();
126
129
label linop_idx = (fuse) ? 0 : in->get_id ()[i];
127
130
label linop_offset = (fuse) ? linop_offset_store : 0 ;
128
131
auto comm_pattern = compute_gather_to_owner_counts (
129
132
exec_handler, ranks_per_owner, interface_size);
130
-
131
133
size_t recv_size = comm_pattern.recv_offsets .back ();
132
134
133
135
// NOTE Probably dont need to store linops[linop-idx] because we can
@@ -395,6 +397,9 @@ void update_impl(
395
397
auto all_to_all_update = [repart_comm, ref_exec, device_exec,
396
398
all_to_all_update_data, host_A, force_host_buffer,
397
399
exec_handler, rank]() {
400
+ // NOTE if symmetric (get it from host_A) we can skip id=0 and wait till
401
+ // id=1 has been copied to use device copy
402
+ //
398
403
for (auto [id, comm_pattern, data_ptr] : all_to_all_update_data) {
399
404
// auto start = std::chrono::steady_clock::now();
400
405
auto repartAllToAll =
@@ -409,24 +414,43 @@ void update_impl(
409
414
// communicate_values(ref_exec, device_exec, repart_comm,
410
415
// repartAllToAll,
411
416
// send_data_ptr, data_ptr, force_host_buffer);
412
- // if ( repart_comm->rank() == 0 ) {
413
417
// std::cout << __FILE__ <<
414
418
// " Pstream::rank " << Pstream::myProcNo() <<
415
419
// " repart_rank() " << repart_comm->rank() <<
416
420
// " send_offsets.back() " <<
421
+ // " id " << id <<
417
422
// repartAllToAll.send_offsets.back() << " recv_counts: " <<
418
423
// repartAllToAll.recv_counts << " recv_offsets: " <<
419
424
// repartAllToAll.recv_offsets <<
420
425
// std::endl;
421
- // }
422
- MPI_Request request;
423
-
424
- MPI_Igatherv (send_data_ptr, repartAllToAll.send_offsets .back (),
425
- MPI_DOUBLE, data_ptr,
426
- repartAllToAll.recv_counts .data (),
427
- repartAllToAll.recv_offsets .data (), MPI_DOUBLE, 0 ,
428
- repart_comm->get (), &request);
429
- MPI_Wait (&request, MPI_STATUS_IGNORE);
426
+
427
+ if (id == 0 && host_A->get_symmetric ()) {
428
+ } else {
429
+ MPI_Request request;
430
+ MPI_Igatherv (send_data_ptr, repartAllToAll.send_offsets .back (),
431
+ MPI_DOUBLE, data_ptr,
432
+ repartAllToAll.recv_counts .data (),
433
+ repartAllToAll.recv_offsets .data (), MPI_DOUBLE, 0 ,
434
+ repart_comm->get (), &request);
435
+ MPI_Wait (&request, MPI_STATUS_IGNORE);
436
+ }
437
+
438
+ // Perform symmetric inter device copy
439
+ if (id == 1 && repart_comm->rank () == 0 &&
440
+ host_A->get_symmetric ()) {
441
+ auto [zid, zcomm_pattern, zdata_ptr] =
442
+ all_to_all_update_data[0 ];
443
+ // copy recv size data from data_ptr to zdata_ptr
444
+ //
445
+ label recv_buffer_size = repartAllToAll.recv_offsets .back ();
446
+ auto l_view = gko::array<scalar>::view (
447
+ device_exec, recv_buffer_size, data_ptr);
448
+
449
+ auto u_view = gko::array<scalar>::view (
450
+ device_exec, recv_buffer_size, zdata_ptr);
451
+
452
+ u_view = l_view;
453
+ }
430
454
}
431
455
};
432
456
0 commit comments