Skip to content

Commit 656c0c2

Browse files
ompi/coll/cuda: implement reduce_local
Signed-off-by: Akshay Venkatesh <akvenkatesh@nvidia.com> bot:notacherrypick
1 parent 778476f commit 656c0c2

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

ompi/mca/coll/cuda/coll_cuda.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2014 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -45,6 +46,11 @@ mca_coll_cuda_allreduce(const void *sbuf, void *rbuf, int count,
4546
struct ompi_communicator_t *comm,
4647
mca_coll_base_module_t *module);
4748

49+
int mca_coll_cuda_reduce_local(const void *sbuf, void *rbuf, int count,
50+
struct ompi_datatype_t *dtype,
51+
struct ompi_op_t *op,
52+
mca_coll_base_module_t *module);
53+
4854
int mca_coll_cuda_reduce(const void *sbuf, void *rbuf, int count,
4955
struct ompi_datatype_t *dtype,
5056
struct ompi_op_t *op,

ompi/mca/coll/cuda/coll_cuda_module.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2014-2017 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -104,6 +105,7 @@ mca_coll_cuda_comm_query(struct ompi_communicator_t *comm,
104105
cuda_module->super.coll_gather = NULL;
105106
cuda_module->super.coll_gatherv = NULL;
106107
cuda_module->super.coll_reduce = mca_coll_cuda_reduce;
108+
cuda_module->super.coll_reduce_local = mca_coll_cuda_reduce_local;
107109
cuda_module->super.coll_reduce_scatter = NULL;
108110
cuda_module->super.coll_reduce_scatter_block = mca_coll_cuda_reduce_scatter_block;
109111
cuda_module->super.coll_scan = mca_coll_cuda_scan;

ompi/mca/coll/cuda/coll_cuda_reduce.c

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*
2+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
23
* Copyright (c) 2004-2015 The University of Tennessee and The University
34
* of Tennessee Research Foundation. All rights
45
* reserved.
@@ -34,7 +35,7 @@ mca_coll_cuda_reduce(const void *sbuf, void *rbuf, int count,
3435
mca_coll_base_module_t *module)
3536
{
3637
mca_coll_cuda_module_t *s = (mca_coll_cuda_module_t*) module;
37-
int rank = ompi_comm_rank(comm);
38+
int rank = (comm == NULL) ? -1 : ompi_comm_rank(comm);
3839
ptrdiff_t gap;
3940
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
4041
const char *sbuf2;
@@ -64,9 +65,15 @@ mca_coll_cuda_reduce(const void *sbuf, void *rbuf, int count,
6465
rbuf2 = rbuf; /* save away original buffer */
6566
rbuf = rbuf1 - gap;
6667
}
67-
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
68-
dtype, op, root, comm,
69-
s->c_coll.coll_reduce_module);
68+
69+
if ((comm != NULL) && (root == -1)) {
70+
ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype);
71+
rc = OMPI_SUCCESS;
72+
} else {
73+
rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count,
74+
dtype, op, root, comm,
75+
s->c_coll.coll_reduce_module);
76+
}
7077

7178
if (NULL != sbuf1) {
7279
free(sbuf1);
@@ -78,3 +85,12 @@ mca_coll_cuda_reduce(const void *sbuf, void *rbuf, int count,
7885
}
7986
return rc;
8087
}
88+
89+
int
90+
mca_coll_cuda_reduce_local(const void *sbuf, void *rbuf, int count,
91+
struct ompi_datatype_t *dtype,
92+
struct ompi_op_t *op,
93+
mca_coll_base_module_t *module)
94+
{
95+
return mca_coll_cuda_reduce(sbuf, rbuf, count, dtype, op, -1, NULL ,module);
96+
}

0 commit comments

Comments
 (0)