Skip to content

libnbc: Fix int overflow when handling count parameters #9616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions ompi/mca/coll/libnbc/nbc.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* Author(s): Torsten Hoefler <htor@cs.indiana.edu>
*
* Copyright (c) 2012 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016 IBM Corporation. All rights reserved.
* Copyright (c) 2016-2021 IBM Corporation. All rights reserved.
* Copyright (c) 2017 Ian Bradley Morgan and Anthony Skjellum. All
* rights reserved.
* Copyright (c) 2018 FUJITSU LIMITED. All rights reserved.
Expand Down Expand Up @@ -119,7 +119,7 @@ static int nbc_schedule_round_append (NBC_Schedule *schedule, void *data, int da
}

/* this function puts a send into the schedule */
static int NBC_Sched_send_internal (const void* buf, char tmpbuf, int count, MPI_Datatype datatype, int dest, bool local, NBC_Schedule *schedule, bool barrier) {
static int NBC_Sched_send_internal (const void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int dest, bool local, NBC_Schedule *schedule, bool barrier) {
NBC_Args_send send_args;
int ret;

Expand All @@ -143,16 +143,16 @@ static int NBC_Sched_send_internal (const void* buf, char tmpbuf, int count, MPI
return OMPI_SUCCESS;
}

int NBC_Sched_send (const void* buf, char tmpbuf, int count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier) {
int NBC_Sched_send (const void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier) {
return NBC_Sched_send_internal (buf, tmpbuf, count, datatype, dest, false, schedule, barrier);
}

int NBC_Sched_local_send (const void* buf, char tmpbuf, int count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier) {
int NBC_Sched_local_send (const void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier) {
return NBC_Sched_send_internal (buf, tmpbuf, count, datatype, dest, true, schedule, barrier);
}

/* this function puts a receive into the schedule */
static int NBC_Sched_recv_internal (void* buf, char tmpbuf, int count, MPI_Datatype datatype, int source, bool local, NBC_Schedule *schedule, bool barrier) {
static int NBC_Sched_recv_internal (void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int source, bool local, NBC_Schedule *schedule, bool barrier) {
NBC_Args_recv recv_args;
int ret;

Expand All @@ -176,16 +176,16 @@ static int NBC_Sched_recv_internal (void* buf, char tmpbuf, int count, MPI_Datat
return OMPI_SUCCESS;
}

int NBC_Sched_recv (void* buf, char tmpbuf, int count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier) {
int NBC_Sched_recv (void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier) {
return NBC_Sched_recv_internal(buf, tmpbuf, count, datatype, source, false, schedule, barrier);
}

int NBC_Sched_local_recv (void* buf, char tmpbuf, int count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier) {
int NBC_Sched_local_recv (void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier) {
return NBC_Sched_recv_internal(buf, tmpbuf, count, datatype, source, true, schedule, barrier);
}

/* this function puts an operation into the schedule */
int NBC_Sched_op (const void* buf1, char tmpbuf1, void* buf2, char tmpbuf2, int count, MPI_Datatype datatype,
int NBC_Sched_op (const void* buf1, char tmpbuf1, void* buf2, char tmpbuf2, size_t count, MPI_Datatype datatype,
MPI_Op op, NBC_Schedule *schedule, bool barrier) {
NBC_Args_op op_args;
int ret;
Expand All @@ -212,7 +212,8 @@ int NBC_Sched_op (const void* buf1, char tmpbuf1, void* buf2, char tmpbuf2, int
}

/* this function puts a copy into the schedule */
int NBC_Sched_copy (void *src, char tmpsrc, int srccount, MPI_Datatype srctype, void *tgt, char tmptgt, int tgtcount,
int NBC_Sched_copy (void *src, char tmpsrc, size_t srccount, MPI_Datatype srctype,
void *tgt, char tmptgt, size_t tgtcount,
MPI_Datatype tgttype, NBC_Schedule *schedule, bool barrier) {
NBC_Args_copy copy_args;
int ret;
Expand Down Expand Up @@ -240,7 +241,7 @@ int NBC_Sched_copy (void *src, char tmpsrc, int srccount, MPI_Datatype srctype,
}

/* this function puts an unpack into the schedule */
int NBC_Sched_unpack (void *inbuf, char tmpinbuf, int count, MPI_Datatype datatype, void *outbuf, char tmpoutbuf,
int NBC_Sched_unpack (void *inbuf, char tmpinbuf, size_t count, MPI_Datatype datatype, void *outbuf, char tmpoutbuf,
NBC_Schedule *schedule, bool barrier) {
NBC_Args_unpack unpack_args;
int ret;
Expand Down Expand Up @@ -534,6 +535,7 @@ static inline int NBC_Start_round(NBC_Handle *handle) {
} else {
buf2=opargs.buf2;
}

ompi_op_reduce(opargs.op, buf1, buf2, opargs.count, opargs.datatype);
break;
case COPY:
Expand Down
27 changes: 14 additions & 13 deletions ompi/mca/coll/libnbc/nbc_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* Copyright (c) 2015 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2018 FUJITSU LIMITED. All rights reserved.
* Copyright (c) 2021 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand Down Expand Up @@ -90,7 +91,7 @@ typedef enum {
/* the send argument struct */
typedef struct {
NBC_Fn_type type;
int count;
size_t count;
const void *buf;
MPI_Datatype datatype;
int dest;
Expand All @@ -101,7 +102,7 @@ typedef struct {
/* the receive argument struct */
typedef struct {
NBC_Fn_type type;
int count;
size_t count;
void *buf;
MPI_Datatype datatype;
char tmpbuf;
Expand All @@ -118,26 +119,26 @@ typedef struct {
void *buf2;
MPI_Op op;
MPI_Datatype datatype;
int count;
size_t count;
} NBC_Args_op;

/* the copy argument struct */
typedef struct {
NBC_Fn_type type;
int srccount;
size_t srccount;
void *src;
void *tgt;
MPI_Datatype srctype;
MPI_Datatype tgttype;
int tgtcount;
size_t tgtcount;
char tmpsrc;
char tmptgt;
} NBC_Args_copy;

/* unpack operation arguments */
typedef struct {
NBC_Fn_type type;
int count;
size_t count;
void *inbuf;
void *outbuf;
MPI_Datatype datatype;
Expand All @@ -146,15 +147,15 @@ typedef struct {
} NBC_Args_unpack;

/* internal function prototypes */
int NBC_Sched_send (const void* buf, char tmpbuf, int count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_local_send (const void* buf, char tmpbuf, int count, MPI_Datatype datatype, int dest,NBC_Schedule *schedule, bool barrier);
int NBC_Sched_recv (void* buf, char tmpbuf, int count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_local_recv (void* buf, char tmpbuf, int count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_op (const void* buf1, char tmpbuf1, void* buf2, char tmpbuf2, int count, MPI_Datatype datatype,
int NBC_Sched_send (const void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int dest, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_local_send (const void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int dest,NBC_Schedule *schedule, bool barrier);
int NBC_Sched_recv (void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_local_recv (void* buf, char tmpbuf, size_t count, MPI_Datatype datatype, int source, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_op (const void* buf1, char tmpbuf1, void* buf2, char tmpbuf2, size_t count, MPI_Datatype datatype,
MPI_Op op, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_copy (void *src, char tmpsrc, int srccount, MPI_Datatype srctype, void *tgt, char tmptgt, int tgtcount,
int NBC_Sched_copy (void *src, char tmpsrc, size_t srccount, MPI_Datatype srctype, void *tgt, char tmptgt, size_t tgtcount,
MPI_Datatype tgttype, NBC_Schedule *schedule, bool barrier);
int NBC_Sched_unpack (void *inbuf, char tmpinbuf, int count, MPI_Datatype datatype, void *outbuf, char tmpoutbuf,
int NBC_Sched_unpack (void *inbuf, char tmpinbuf, size_t count, MPI_Datatype datatype, void *outbuf, char tmpoutbuf,
NBC_Schedule *schedule, bool barrier);

int NBC_Sched_barrier (NBC_Schedule *schedule);
Expand Down
6 changes: 4 additions & 2 deletions ompi/mca/coll/libnbc/nbc_ireduce_scatter.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
static int nbc_reduce_scatter_init(const void* sendbuf, void* recvbuf, const int *recvcounts, MPI_Datatype datatype,
MPI_Op op, struct ompi_communicator_t *comm, ompi_request_t ** request,
mca_coll_base_module_t *module, bool persistent) {
int peer, rank, maxr, p, res, count;
int peer, rank, maxr, p, res;
size_t count;
MPI_Aint ext;
ptrdiff_t gap, span, span_align;
char *sbuf, inplace;
Expand Down Expand Up @@ -230,7 +231,8 @@ int ompi_coll_libnbc_ireduce_scatter (const void* sendbuf, void* recvbuf, const
static int nbc_reduce_scatter_inter_init (const void* sendbuf, void* recvbuf, const int *recvcounts, MPI_Datatype datatype,
MPI_Op op, struct ompi_communicator_t *comm, ompi_request_t ** request,
mca_coll_base_module_t *module, bool persistent) {
int rank, res, count, lsize, rsize;
int rank, res, lsize, rsize;
size_t count;
MPI_Aint ext;
ptrdiff_t gap, span, span_align;
NBC_Schedule *schedule;
Expand Down
6 changes: 4 additions & 2 deletions ompi/mca/coll/libnbc/nbc_ireduce_scatter_block.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
static int nbc_reduce_scatter_block_init(const void* sendbuf, void* recvbuf, int recvcount, MPI_Datatype datatype,
MPI_Op op, struct ompi_communicator_t *comm, ompi_request_t ** request,
mca_coll_base_module_t *module, bool persistent) {
int peer, rank, maxr, p, res, count;
int peer, rank, maxr, p, res;
size_t count;
MPI_Aint ext;
ptrdiff_t gap, span;
char *redbuf, *sbuf, inplace;
Expand Down Expand Up @@ -229,7 +230,8 @@ int ompi_coll_libnbc_ireduce_scatter_block(const void* sendbuf, void* recvbuf, i
static int nbc_reduce_scatter_block_inter_init(const void *sendbuf, void *recvbuf, int rcount, struct ompi_datatype_t *dtype,
struct ompi_op_t *op, struct ompi_communicator_t *comm, ompi_request_t **request,
mca_coll_base_module_t *module, bool persistent) {
int rank, res, count, lsize, rsize;
int rank, res, lsize, rsize;
size_t count;
MPI_Aint ext;
ptrdiff_t gap, span, span_align;
NBC_Schedule *schedule;
Expand Down
34 changes: 33 additions & 1 deletion ompi/op/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2021 IBM Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand Down Expand Up @@ -510,10 +511,41 @@ static inline bool ompi_op_is_valid(ompi_op_t * op, ompi_datatype_t * ddt,
* is not defined to have that operation, it is likely to seg fault.
*/
static inline void ompi_op_reduce(ompi_op_t * op, void *source,
void *target, int count,
void *target, size_t full_count,
ompi_datatype_t * dtype)
{
MPI_Fint f_dtype, f_count;
int count = full_count;

/*
* If the full_count is > INT_MAX then we need to call the reduction op
* in iterations of counts <= INT_MAX since it has an `int *len`
* parameter.
*
* Note: When we add BigCount support then we can distinguish between
* a reduction operation with `int *len` and `MPI_Count *len`. At which
* point we can avoid this loop.
*/
if( OPAL_UNLIKELY(full_count > INT_MAX) ) {
size_t done_count = 0, shift;
int iter_count;
ptrdiff_t ext, lb;

ompi_datatype_get_extent(dtype, &lb, &ext);

while(done_count < full_count) {
if(done_count + INT_MAX > full_count) {
iter_count = full_count - done_count;
} else {
iter_count = INT_MAX;
}
shift = done_count * ext;
// Recurse one level in iterations of 'int'
ompi_op_reduce(op, (char*)source + shift, (char*)target + shift, iter_count, dtype);
done_count += iter_count;
}
return;
}

/*
* Call the reduction function. Two dimensions: a) if both the op
Expand Down