Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,12 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
mca_pml_ucx_get_datatype(datatype),
ucp_tag, ucp_tag_mask, req);

for (;;) {
MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
status = ucp_request_test(req, &info);
if (status != UCS_INPROGRESS) {
mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
return OMPI_SUCCESS;
}
opal_progress();
}
}

Expand Down Expand Up @@ -685,16 +684,12 @@ mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,

req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
mca_pml_ucx_get_datatype(datatype),
tag, mode,
mca_pml_ucx_send_completion);

tag, mode, cb);
if (OPAL_LIKELY(req == NULL)) {
return OMPI_SUCCESS;
} else if (!UCS_PTR_IS_ERR(req)) {
PML_UCX_VERBOSE(8, "got request %p", (void*)req);
ucp_worker_progress(ompi_pml_ucx.ucp_worker);
ompi_request_wait(&req, MPI_STATUS_IGNORE);
return OMPI_SUCCESS;
MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", ompi_request_free(&req));
} else {
PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
return OMPI_ERROR;
Expand All @@ -707,7 +702,7 @@ mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
ucp_datatype_t ucx_datatype, ucp_tag_t tag)

{
void *req;
ucs_status_ptr_t req;
ucs_status_t status;

/* coverity[bad_alloc_arithmetic] */
Expand All @@ -717,12 +712,7 @@ mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
return OMPI_SUCCESS;
}

ucp_worker_progress(ompi_pml_ucx.ucp_worker);
while ((status = ucp_request_check_status(req)) == UCS_INPROGRESS) {
opal_progress();
}

return OPAL_LIKELY(UCS_OK == status) ? OMPI_SUCCESS : OMPI_ERROR;
MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", (void)0);
}
#endif

Expand Down Expand Up @@ -758,6 +748,8 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
int *matched, ompi_status_public_t* mpi_status)
{
static unsigned progress_count = 0;

ucp_tag_t ucp_tag, ucp_tag_mask;
ucp_tag_recv_info_t info;
ucp_tag_message_h ucp_msg;
Expand All @@ -770,8 +762,9 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
if (ucp_msg != NULL) {
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else {
opal_progress();
} else {
(++progress_count % opal_common_ucx.progress_iterations) ?
(void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;
Expand All @@ -787,22 +780,23 @@ int mca_pml_ucx_probe(int src, int tag, struct ompi_communicator_t* comm,
PML_UCX_TRACE_PROBE("probe", src, tag, comm);

PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
for (;;) {
ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
0, &info);

MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag,
ucp_tag_mask, 0, &info);
if (ucp_msg != NULL) {
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
return OMPI_SUCCESS;
}

opal_progress();
}
}

int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
int *matched, struct ompi_message_t **message,
ompi_status_public_t* mpi_status)
int *matched, struct ompi_message_t **message,
ompi_status_public_t* mpi_status)
{
static unsigned progress_count = 0;

ucp_tag_t ucp_tag, ucp_tag_mask;
ucp_tag_recv_info_t info;
ucp_tag_message_h ucp_msg;
Expand All @@ -818,7 +812,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else {
opal_progress();
(++progress_count % opal_common_ucx.progress_iterations) ?
(void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;
Expand All @@ -835,7 +830,7 @@ int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm,
PML_UCX_TRACE_PROBE("mprobe", src, tag, comm);

PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
for (;;) {
MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
1, &info);
if (ucp_msg != NULL) {
Expand All @@ -844,8 +839,6 @@ int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm,
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
return OMPI_SUCCESS;
}

opal_progress();
}
}

Expand Down
68 changes: 38 additions & 30 deletions opal/mca/common/ucx/common_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,33 @@ BEGIN_C_DECLS
__VA_ARGS__); \
}

/* progress loop to allow call UCX/opal progress */
/* used C99 for-statement variable initialization */
#define MCA_COMMON_UCX_PROGRESS_LOOP(_worker) \
for (unsigned iter = 0;; (++iter % opal_common_ucx.progress_iterations) ? \
(void)ucp_worker_progress(_worker) : opal_progress())

#define MCA_COMMON_UCX_WAIT_LOOP(_request, _worker, _msg, _completed) \
do { \
ucs_status_t status; \
/* call UCX progress */ \
MCA_COMMON_UCX_PROGRESS_LOOP(_worker) { \
status = opal_common_ucx_request_status(_request); \
if (UCS_INPROGRESS != status) { \
_completed; \
if (OPAL_LIKELY(UCS_OK == status)) { \
return OPAL_SUCCESS; \
} else { \
MCA_COMMON_UCX_VERBOSE(1, "%s failed: %d, %s", \
(_msg) ? (_msg) : __FUNCTION__, \
UCS_PTR_STATUS(_request), \
ucs_status_string(UCS_PTR_STATUS(_request))); \
return OPAL_ERROR; \
} \
} \
} \
} while (0)

typedef struct opal_common_ucx_module {
int output;
int verbose;
Expand All @@ -75,15 +102,21 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, s
size_t my_rank, size_t max_disconnect, ucp_worker_h worker);

static inline
int opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker,
const char *msg)
ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request)
{
ucs_status_t status;
int i;
#if !HAVE_DECL_UCP_REQUEST_CHECK_STATUS
ucp_tag_recv_info_t info;

return ucp_request_test(request, &info);
#else
return ucp_request_check_status(request);
#endif
}

static inline
int opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker,
const char *msg)
{
/* check for request completed or failed */
if (OPAL_LIKELY(UCS_OK == request)) {
return OPAL_SUCCESS;
Expand All @@ -94,32 +127,7 @@ int opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker,
return OPAL_ERROR;
}

while (1) {
/* call UCX progress */
for (i = 0; i < opal_common_ucx.progress_iterations; i++) {
if (UCS_INPROGRESS != (status =
#if HAVE_DECL_UCP_REQUEST_CHECK_STATUS
ucp_request_check_status(request)
#else
ucp_request_test(request, &info)
#endif
)) {
ucp_request_free(request);
if (OPAL_LIKELY(UCS_OK == status)) {
return OPAL_SUCCESS;
} else {
MCA_COMMON_UCX_VERBOSE(1, "%s failed: %d, %s", msg ? msg : __func__,
UCS_PTR_STATUS(request),
ucs_status_string(UCS_PTR_STATUS(request)));
return OPAL_ERROR;
}
}
ucp_worker_progress(worker);
}
/* call OPAL progress on every opal_common_ucx_progress_iterations
* calls to UCX progress */
opal_progress();
}
MCA_COMMON_UCX_WAIT_LOOP(request, worker, msg, ucp_request_free(request));
}

static inline
Expand Down