* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 Siberian State University of Telecommunications
* and Information Sciences. All rights reserved.
+ * Copyright (c) 2022 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
                                                 struct ompi_communicator_t *comm,
                                                 mca_coll_base_module_t *module)
{
-    int rank, size, count, err = OMPI_SUCCESS;
+    int rank, size, err = OMPI_SUCCESS;
+    size_t count;
    ptrdiff_t gap, span;
    char *recv_buf = NULL, *recv_buf_free = NULL;
@@ -67,7 +69,7 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
    size = ompi_comm_size(comm);

    /* short cut the trivial case */
-    count = rcount * size;
+    count = rcount * (size_t)size;
    if (0 == count) {
        return OMPI_SUCCESS;
    }
@@ -91,17 +93,38 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
        recv_buf = recv_buf_free - gap;
    }

-    /* reduction */
-    err =
-        comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
-                                  comm, comm->c_coll->coll_reduce_module);
+    if (OPAL_UNLIKELY(count > INT_MAX)) {
+        // Sending the message in the coll_reduce as "rcount*size" would exceed
+        // the 'int count' parameter in the coll_reduce() function. Instead reduce
+        // the result in "rcount" chunks.
+        int i;
+        void *rbuf_ptr, *sbuf_ptr;
+        span = opal_datatype_span(&dtype->super, rcount, &gap);
+        for (i = 0; i < size; ++i) {
+            rbuf_ptr = (char *)recv_buf + span * (size_t)i;
+            sbuf_ptr = (char *)sbuf + span * (size_t)i;
+            /* reduction */
+            err =
+                comm->c_coll->coll_reduce(sbuf_ptr, rbuf_ptr, rcount, dtype, op, 0,
+                                          comm, comm->c_coll->coll_reduce_module);
+            if (MPI_SUCCESS != err) {
+                goto cleanup;
+            }
+        }
+    } else {
+        /* reduction */
+        err =
+            comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
+                                      comm, comm->c_coll->coll_reduce_module);
+        if (MPI_SUCCESS != err) {
+            goto cleanup;
+        }
+    }

    /* scatter */
-    if (MPI_SUCCESS == err) {
-        err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
-                                         rbuf, rcount, dtype, 0,
-                                         comm, comm->c_coll->coll_scatter_module);
-    }
+    err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
+                                     rbuf, rcount, dtype, 0,
+                                     comm, comm->c_coll->coll_scatter_module);

cleanup:
    if (NULL != recv_buf_free) free(recv_buf_free);
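
Why the chunking above works: each call in the loop reduces exactly rcount elements, which by definition fits the int count parameter of coll_reduce(), while the per-chunk buffer offsets are computed in size_t so they cannot wrap. A minimal standalone sketch of the same arithmetic, with hypothetical values in place of a real communicator (no MPI calls; the loop body only marks where the per-chunk reduction would go):

    #include <limits.h>
    #include <stdio.h>

    int main(void)
    {
        int rcount = 1 << 20;   /* elements per rank (hypothetical) */
        int size = 4096;        /* communicator size (hypothetical) */

        /* Combined count, computed in size_t: 2^32 elements here. */
        size_t count = rcount * (size_t)size;
        printf("combined count %zu fits in int: %s\n",
               count, count <= INT_MAX ? "yes" : "no");

        if (count > INT_MAX) {
            for (int i = 0; i < size; ++i) {
                /* Per-chunk offset in elements; size_t keeps it from wrapping.
                 * A real code path would pass rcount (<= INT_MAX) to the
                 * reduction here, as the hunk above does. */
                size_t offset = (size_t)i * rcount;
                (void)offset;
            }
        }
        return 0;
    }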
@@ -146,7 +169,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
    if (comm_size < 2)
        return MPI_SUCCESS;

-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
+    if (OPAL_UNLIKELY(totalcount > INT_MAX)) {
+        /*
+         * Large payload collectives are not supported by this algorithm.
+         * The blocklens and displs calculations in the loop below
+         * will overflow an int data type.
+         * Fallback to the linear algorithm.
+         */
+        return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
+    }
    ompi_datatype_type_extent(dtype, &extent);
    span = opal_datatype_span(&dtype->super, totalcount, &gap);
    tmpbuf_raw = malloc(span);
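
The fallback guard above only works because totalcount is widened before the multiply: computing comm_size * rcount in plain int would overflow first (undefined behavior), and the comparison against INT_MAX would never fire. A sketch of the pattern with hypothetical sizes:

    #include <limits.h>
    #include <stdio.h>

    int main(void)
    {
        int comm_size = 65536, rcount = 65536;   /* hypothetical sizes */

        /* Cast one operand first so the product is evaluated in size_t;
         * "comm_size * rcount" alone would overflow int. */
        size_t totalcount = comm_size * (size_t)rcount;

        if (totalcount > INT_MAX) {
            /* The commit falls back to the linear algorithm at this point. */
            printf("fallback needed: totalcount = %zu\n", totalcount);
        }
        return 0;
    }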
@@ -347,7 +379,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
        return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
                                                                op, comm, module);
    }
-    totalcount = comm_size * rcount;
+
+    totalcount = comm_size * (size_t)rcount;
    ompi_datatype_type_extent(dtype, &extent);
    span = opal_datatype_span(&dtype->super, totalcount, &gap);
    tmpbuf_raw = malloc(span);
@@ -431,22 +464,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
             * have their result calculated by the process to their
             * right (rank + 1).
             */
-            int send_count = 0, recv_count = 0;
+            size_t send_count = 0, recv_count = 0;
            if (vrank < vpeer) {
                /* Send the right half of the buffer, recv the left half */
                send_index = recv_index + mask;
-                send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
-                recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
+                send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
+                recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
            } else {
                /* Send the left half of the buffer, recv the right half */
                recv_index = send_index + mask;
-                send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
-                recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
+                send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
+                recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
            }
-            ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
-                                         2 * recv_index : nprocs_rem + recv_index);
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
-                                         2 * send_index : nprocs_rem + send_index);
+            ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
+                                                 2 * recv_index : nprocs_rem + recv_index);
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
+                                                 2 * send_index : nprocs_rem + send_index);
            struct ompi_request_t *request = NULL;

            if (recv_count > 0) {
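
The same widening pattern is applied to every count and displacement built from ompi_range_sum(): casting the helper's int result to size_t forces the multiplication by rcount into size_t. A sketch using a simplified stand-in for ompi_range_sum() (the stand-in and all values are hypothetical, not the library's actual semantics):

    #include <limits.h>
    #include <stdio.h>

    /* Hypothetical stand-in: blocks in [a, b], where indices up to r own two. */
    static int range_sum(int a, int b, int r)
    {
        int sum = 0;
        for (int i = a; i <= b; ++i)
            sum += (i <= r) ? 2 : 1;
        return sum;
    }

    int main(void)
    {
        int rcount = 1 << 20;                    /* hypothetical */
        int nblocks = range_sum(0, 4095, 511);   /* 4608: a small int */

        /* The cast widens the product; in plain int it would overflow. */
        size_t send_count = rcount * (size_t)nblocks;
        printf("send_count = %zu (INT_MAX is %d)\n", send_count, INT_MAX);
        return 0;
    }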
@@ -587,7 +620,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
            sbuf, rbuf, rcount, dtype, op, comm, module);
    }

-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
    ompi_datatype_type_extent(dtype, &extent);
    span = opal_datatype_span(&dtype->super, totalcount, &gap);
    tmpbuf[0] = malloc(span);
@@ -677,13 +710,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
                /* Send the upper half of reduction buffer, recv the lower half */
                recv_index += nblocks;
            }
-            int send_count = rcount * ompi_range_sum(send_index,
-                                                     send_index + nblocks - 1, nprocs_rem - 1);
-            int recv_count = rcount * ompi_range_sum(recv_index,
-                                                     recv_index + nblocks - 1, nprocs_rem - 1);
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+            size_t send_count = rcount *
+                                (size_t)ompi_range_sum(send_index,
+                                                       send_index + nblocks - 1,
+                                                       nprocs_rem - 1);
+            size_t recv_count = rcount *
+                                (size_t)ompi_range_sum(recv_index,
+                                                       recv_index + nblocks - 1,
+                                                       nprocs_rem - 1);
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                         2 * send_index : nprocs_rem + send_index);
-            ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
+            ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
                                         2 * recv_index : nprocs_rem + recv_index);

            err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +756,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
                 * Process has two blocks: for excluded process and own.
                 * Send result to the excluded process.
                 */
-                ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+                ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                             2 * send_index : nprocs_rem + send_index);
                err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
                                        rcount, dtype, peer - 1,
@@ -729,7 +766,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
            }

            /* Send result to a remote process according to a mirror permutation */
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                         2 * send_index : nprocs_rem + send_index);
            /* If process has two blocks, then send the second block (own block) */
            if (vpeer < nprocs_rem)
@@ -821,7 +858,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
    if (rcount == 0 || comm_size < 2)
        return MPI_SUCCESS;

-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
    ompi_datatype_type_extent(dtype, &extent);
    span = opal_datatype_span(&dtype->super, totalcount, &gap);
    tmpbuf[0] = malloc(span);
@@ -843,7 +880,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
    }

-    int nblocks = totalcount, send_index = 0, recv_index = 0;
+    size_t nblocks = totalcount, send_index = 0, recv_index = 0;
    for (int mask = 1; mask < comm_size; mask <<= 1) {
        int peer = rank ^ mask;
        nblocks /= 2;
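
In the pof2 loop above, nblocks starts at totalcount, which after this commit may exceed INT_MAX, and is halved once per butterfly round, hence the widening to size_t. A sketch of that halving with hypothetical sizes:

    #include <stdio.h>

    int main(void)
    {
        int comm_size = 8;                            /* power of two (hypothetical) */
        size_t rcount = (size_t)1 << 29;              /* hypothetical */
        size_t nblocks = (size_t)comm_size * rcount;  /* 2^32: too large for int */

        for (int mask = 1; mask < comm_size; mask <<= 1) {
            nblocks /= 2;   /* each exchange halves the range a rank keeps */
            printf("mask %d: nblocks = %zu\n", mask, nblocks);
        }
        return 0;
    }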