
Commit 61120c3

Fix base reduce_scatter_block for large payloads
* Update the four base reduce_scatter_block algorithms to support large payload
  collectives.
  - The recursive doubling collective fix would have required changing some
    ompi_datatype functions, which was more extensive than I wanted to go after
    in this commit. So if a large payload is expected in that collective, it
    falls back to the linear algorithm.

Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
Parent: 34685a2
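For context, a minimal sketch of the kind of call this commit targets. The rank count, element count, and datatype below are hypothetical (buffers of this size would be tens of GiB per rank), and initialization and error checking are omitted; the point is only that the per-rank recvcount still fits in an int while rcount * comm_size exceeds INT_MAX, which is what overflowed the internal int bookkeeping before this fix.

    #include <mpi.h>
    #include <stdlib.h>

    int main(int argc, char **argv)
    {
        MPI_Init(&argc, &argv);

        int size;
        MPI_Comm_size(MPI_COMM_WORLD, &size);      /* assume e.g. 4096 ranks */

        int rcount = 1 << 20;                      /* 1 Mi doubles per rank (assumed) */
        size_t total = (size_t)rcount * size;      /* 2^32 at 4096 ranks: > INT_MAX   */

        double *sbuf = malloc(total * sizeof(double));          /* full contribution  */
        double *rbuf = malloc((size_t)rcount * sizeof(double)); /* this rank's block  */

        MPI_Reduce_scatter_block(sbuf, rbuf, rcount, MPI_DOUBLE, MPI_SUM,
                                 MPI_COMM_WORLD);

        free(sbuf);
        free(rbuf);
        MPI_Finalize();
        return 0;
    }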


ompi/mca/coll/base/coll_base_reduce_scatter_block.c

Lines changed: 70 additions & 33 deletions
@@ -17,6 +17,7 @@
  * and Technology (RIST). All rights reserved.
  * Copyright (c) 2018      Siberian State University of Telecommunications
  *                         and Information Sciences. All rights reserved.
+ * Copyright (c) 2022      IBM Corporation. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
                                                  struct ompi_communicator_t *comm,
                                                  mca_coll_base_module_t *module)
 {
-    int rank, size, count, err = OMPI_SUCCESS;
+    int rank, size, err = OMPI_SUCCESS;
+    size_t count;
     ptrdiff_t gap, span;
     char *recv_buf = NULL, *recv_buf_free = NULL;
 
@@ -67,7 +69,7 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
     size = ompi_comm_size(comm);
 
     /* short cut the trivial case */
-    count = rcount * size;
+    count = rcount * (size_t)size;
     if (0 == count) {
         return OMPI_SUCCESS;
     }
@@ -91,17 +93,38 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
         recv_buf = recv_buf_free - gap;
     }
 
-    /* reduction */
-    err =
-        comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
-                                  comm, comm->c_coll->coll_reduce_module);
+    if ( OPAL_UNLIKELY(count > INT_MAX) ) {
+        // Sending the message in the coll_reduce as "rcount*size" would exceed
+        // the 'int count' parameter in the coll_reduce() function. Instead reduce
+        // the result in "rcount" chunks.
+        int i;
+        void *rbuf_ptr, *sbuf_ptr;
+        span = opal_datatype_span(&dtype->super, rcount, &gap);
+        for( i = 0; i < size; ++i ) {
+            rbuf_ptr = (char*)recv_buf + span * (size_t)i;
+            sbuf_ptr = (char*)sbuf + span * (size_t)i;
+            /* reduction */
+            err =
+                comm->c_coll->coll_reduce(sbuf_ptr, rbuf_ptr, rcount, dtype, op, 0,
+                                          comm, comm->c_coll->coll_reduce_module);
+            if (MPI_SUCCESS != err) {
+                goto cleanup;
+            }
+        }
+    } else {
+        /* reduction */
+        err =
+            comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
+                                      comm, comm->c_coll->coll_reduce_module);
+        if (MPI_SUCCESS != err) {
+            goto cleanup;
+        }
+    }
 
     /* scatter */
-    if (MPI_SUCCESS == err) {
-        err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
-                                         rbuf, rcount, dtype, 0,
-                                         comm, comm->c_coll->coll_scatter_module);
-    }
+    err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
+                                     rbuf, rcount, dtype, 0,
+                                     comm, comm->c_coll->coll_scatter_module);
 
 cleanup:
     if (NULL != recv_buf_free) free(recv_buf_free);
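Outside the diff, the chunking idea in the hunk above can be sketched against the public MPI API. This helper is illustrative only: the function name, the MPI_DOUBLE/MPI_SUM choice, and root 0 are assumptions, MPI_IN_PLACE is not handled, and it is not the Open MPI internal code.

    #include <limits.h>
    #include <mpi.h>

    /* Reduce rcount * size elements to the root in rcount-sized blocks whenever
     * the total would not fit in the int count parameter of MPI_Reduce.  rbuf
     * must hold rcount * size elements on the root, mirroring recv_buf above. */
    int reduce_large(const double *sbuf, double *rbuf, int rcount, int size,
                     MPI_Comm comm)
    {
        size_t total = (size_t)rcount * size;

        if (total <= INT_MAX) {
            return MPI_Reduce(sbuf, rbuf, (int)total, MPI_DOUBLE, MPI_SUM, 0, comm);
        }

        for (int i = 0; i < size; ++i) {          /* one reduction per block */
            int rc = MPI_Reduce(sbuf + (size_t)i * rcount,
                                rbuf + (size_t)i * rcount,
                                rcount, MPI_DOUBLE, MPI_SUM, 0, comm);
            if (MPI_SUCCESS != rc) {
                return rc;
            }
        }
        return MPI_SUCCESS;
    }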
@@ -146,7 +169,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
     if (comm_size < 2)
         return MPI_SUCCESS;
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
+    if( OPAL_UNLIKELY(totalcount > INT_MAX) ) {
+        /*
+         * Large payload collectives are not supported by this algorithm.
+         * The blocklens and displs calculations in the loop below
+         * will overflow an int data type.
+         * Fallback to the linear algorithm.
+         */
+        return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
+    }
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf_raw = malloc(span);
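The fallback added above exists because the block lengths and displacements the recursive-doubling loop computes are plain ints. A generic illustration of that ceiling, using the public MPI_Type_indexed binding rather than the internal ompi_datatype calls the commit message mentions (the function name and values are hypothetical):

    #include <limits.h>
    #include <mpi.h>

    /* MPI_Type_indexed takes int arrays for block lengths and displacements, so
     * a displacement can never address past INT_MAX elements; once totalcount
     * exceeds INT_MAX the trailing blocks are unreachable with this approach,
     * which is why the algorithm defers to the linear implementation instead. */
    void indexed_displacement_limit(MPI_Datatype dtype, MPI_Datatype *newtype)
    {
        int blocklens[1] = { 1 };
        int displs[1]    = { INT_MAX };   /* largest representable element offset */

        MPI_Type_indexed(1, blocklens, displs, dtype, newtype);
        MPI_Type_commit(newtype);
    }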
@@ -347,7 +379,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
         return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
                                                                 op, comm, module);
     }
-    totalcount = comm_size * rcount;
+
+    totalcount = comm_size * (size_t)rcount;
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf_raw = malloc(span);
@@ -431,22 +464,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
          * have their result calculated by the process to their
          * right (rank + 1).
          */
-        int send_count = 0, recv_count = 0;
+        size_t send_count = 0, recv_count = 0;
         if (vrank < vpeer) {
             /* Send the right half of the buffer, recv the left half */
             send_index = recv_index + mask;
-            send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
-            recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
+            send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
+            recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
         } else {
             /* Send the left half of the buffer, recv the right half */
             recv_index = send_index + mask;
-            send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
-            recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
+            send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
+            recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
         }
-        ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
-                                     2 * recv_index : nprocs_rem + recv_index);
-        ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
-                                     2 * send_index : nprocs_rem + send_index);
+        ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
+                                             2 * recv_index : nprocs_rem + recv_index);
+        ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
+                                             2 * send_index : nprocs_rem + send_index);
         struct ompi_request_t *request = NULL;
 
         if (recv_count > 0) {
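The repeated (size_t) casts above are about C's arithmetic-conversion rules rather than the destination type: rcount and the index are both ints, so without a cast the product is computed in int and can overflow before it is widened to ptrdiff_t. A standalone sketch with hypothetical values:

    #include <stddef.h>
    #include <stdio.h>

    int main(void)
    {
        int rcount     = 1 << 20;   /* elements per block (assumed)          */
        int recv_index = 4096;      /* block index on a large job (assumed)  */

        /* rcount * recv_index == 2^32, which overflows int; widening one
         * operand first keeps the multiplication in 64-bit arithmetic. */
        ptrdiff_t rdispl = rcount * (size_t)recv_index;

        printf("rdispl = %td elements\n", rdispl);
        return 0;
    }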
@@ -587,7 +620,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
             sbuf, rbuf, rcount, dtype, op, comm, module);
     }
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf[0] = malloc(span);
@@ -677,13 +710,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
                 /* Send the upper half of reduction buffer, recv the lower half */
                 recv_index += nblocks;
             }
-            int send_count = rcount * ompi_range_sum(send_index,
-                                                     send_index + nblocks - 1, nprocs_rem - 1);
-            int recv_count = rcount * ompi_range_sum(recv_index,
-                                                     recv_index + nblocks - 1, nprocs_rem - 1);
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+            size_t send_count = rcount *
+                (size_t)ompi_range_sum(send_index,
+                                       send_index + nblocks - 1,
+                                       nprocs_rem - 1);
+            size_t recv_count = rcount *
+                (size_t)ompi_range_sum(recv_index,
+                                       recv_index + nblocks - 1,
+                                       nprocs_rem - 1);
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                          2 * send_index : nprocs_rem + send_index);
-            ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
+            ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
                                          2 * recv_index : nprocs_rem + recv_index);
 
             err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +756,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
                  * Process has two blocks: for excluded process and own.
                  * Send result to the excluded process.
                  */
-                ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+                ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                              2 * send_index : nprocs_rem + send_index);
                 err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
                                         rcount, dtype, peer - 1,
@@ -729,7 +766,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
             }
 
             /* Send result to a remote process according to a mirror permutation */
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                          2 * send_index : nprocs_rem + send_index);
             /* If process has two blocks, then send the second block (own block) */
             if (vpeer < nprocs_rem)
@@ -821,7 +858,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
     if (rcount == 0 || comm_size < 2)
         return MPI_SUCCESS;
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
    ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf[0] = malloc(span);
@@ -843,7 +880,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
         if (MPI_SUCCESS != err) { goto cleanup_and_return; }
     }
 
-    int nblocks = totalcount, send_index = 0, recv_index = 0;
+    size_t nblocks = totalcount, send_index = 0, recv_index = 0;
     for (int mask = 1; mask < comm_size; mask <<= 1) {
         int peer = rank ^ mask;
         nblocks /= 2;
