Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/mpid/ch4/netmod/ofi/ofi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@ int MPIDI_OFI_am_rdma_read_ack_handler(void *am_hdr, void *data,
MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req);
int MPIDI_OFI_rndv_info_handler(void *am_hdr, void *data, MPI_Aint data_sz,
uint32_t attr, MPIR_Request ** req);

int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype origin_datatype,
int target_rank, MPI_Aint target_disp, MPI_Aint target_count,
MPI_Datatype target_datatype, MPIR_Win * win);
int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint data_sz,
uint32_t attr, MPIR_Request ** req);
int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint data_sz,
uint32_t attr, MPIR_Request ** req);

int MPIDI_OFI_control_dispatch(void *buf);
void MPIDI_OFI_index_datatypes(struct fid_ep *ep);
int MPIDI_OFI_mr_key_allocator_init(void);
Expand All @@ -307,6 +316,8 @@ void MPIDI_OFI_mr_key_allocator_destroy(void);
int MPIDI_OFI_datatype_to_ofi(MPI_Datatype dt, enum fi_datatype *fi_dt);
int MPIDI_OFI_op_to_ofi(MPI_Op op, enum fi_op *fi_op);

int MPIDI_OFI_rdmaread_poll(MPIX_Async_thing thing);

/* RMA */
#define MPIDI_OFI_INIT_CHUNK_CONTEXT(win,sigreq) \
do { \
Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch4/netmod/ofi/ofi_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,8 @@ static int am_init(int vci)
if (vci == 0) {
MPIDIG_am_reg_cb(MPIDI_OFI_AM_RDMA_READ_ACK, NULL, &MPIDI_OFI_am_rdma_read_ack_handler);
MPIDIG_am_reg_cb(MPIDI_OFI_RNDV_INFO, NULL, &MPIDI_OFI_rndv_info_handler);
MPIDIG_am_reg_cb(MPIDI_OFI_GET_REQ, NULL, &MPIDI_OFI_get_handler);
MPIDIG_am_reg_cb(MPIDI_OFI_GET_ACK, NULL, &MPIDI_OFI_getack_handler);
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/mpid/ch4/netmod/ofi/ofi_pre.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,12 @@ typedef struct {

typedef struct {
MPIDI_OFI_RNDV_COMMON_FIELDS;
int num_nics;
MPI_Aint sz_per_nic;
union {
struct {
const void *data;
struct fid_mr *mr0;
struct fid_mr **mrs;
} send;
struct {
Expand All @@ -246,11 +248,14 @@ typedef struct {
int copy_infly; /* need_pack */
} u;
uint64_t remote_base;
uint64_t rkey0; /* avoid malloc when num_nics == 1 */
uint64_t *rkeys;
MPI_Aint chunks_per_nic;
MPI_Aint cur_chunk_index;
int num_infly;
bool all_issued;
int (*cmpl_cb) (void *context); /* context will be cast to (MPIR_Request *) */
void *context;
} recv;
} u;
} MPIDI_OFI_rndvread_t;
Expand Down Expand Up @@ -381,6 +386,9 @@ typedef struct {
struct MPIDI_OFI_win_request *syncQ;
struct MPIDI_OFI_win_request *deferredQ;
MPIDI_OFI_win_targetinfo_t *winfo;
void *mirror_buf; /* used in gpu fallback paths to avoid repeated host registration */
MPL_pointer_attr_t base_attr;
MPL_pointer_attr_t mirror_attr;

MPL_gavl_tree_t *dwin_target_mrs; /* MR key and address pairs registered to remote processes.
* One AVL tree per process. */
Expand Down
252 changes: 252 additions & 0 deletions src/mpid/ch4/netmod/ofi/ofi_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,255 @@ int MPIDI_OFI_issue_deferred_rma(MPIR_Win * win)
fn_fail:
goto fn_exit;
}

/* -- active message fallback using mirror buffers -- */

/* assumptions:
* 1. both origin and target datatypes are contig
* 2. data_sz <= MPIDI_OFI_global.max_msg_size
*/

/* Get using AM mirror buffer -
* 1. Origin send am MPIDI_OFI_GET_REQ
* 2. Target async localcopy to mirror buffer
* 3. Target send am MPIDI_OFI_GET_ACK
* 4. Origin RDMA read
* 5. Origin complete
*/

struct get_context {
MPIR_Win *win;
int target_rank;
MPI_Aint data_sz;
void *origin_addr;
MPI_Aint target_offset;
MPIR_Request *req;
};

struct get_hdr {
uint64_t win_id;
int origin_rank;
void *origin_context;
MPI_Aint target_offset;
MPI_Aint data_sz;
};

/* origin side - issue AM req */
int MPIDI_OFI_mirror_get(void *origin_addr, MPI_Aint origin_count, MPI_Datatype origin_datatype,
int target_rank, MPI_Aint target_disp, MPI_Aint target_count,
MPI_Datatype target_datatype, MPIR_Win * win)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

/* query target datatype */
int is_contig;
MPIR_Datatype_is_contig(target_datatype, &is_contig);

MPI_Aint data_sz;
MPIR_Datatype_get_size_macro(origin_datatype, data_sz);
data_sz *= origin_count;

MPI_Aint origin_true_lb, target_true_lb;
MPIR_Datatype_get_true_lb(target_datatype, &target_true_lb);
MPIR_Datatype_get_true_lb(origin_datatype, &origin_true_lb);

int vci = MPIDI_WIN(win, am_vci);
int vci_target = MPIDI_WIN_TARGET_VCI(win, target_rank);

/* fill origin context */
struct get_context *origin_context;
origin_context = MPL_malloc(sizeof(struct get_context), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP((origin_context == NULL), mpi_errno, MPI_ERR_OTHER, "**nomem");

origin_context->win = win;
origin_context->target_rank = target_rank;
origin_context->data_sz = data_sz;
origin_context->origin_addr = (char *) origin_addr + origin_true_lb;
origin_context->target_offset = target_disp * win->disp_unit + target_true_lb;

/* allocate a request, used for reuse the code from ofi_rndv_read. */
MPIR_Request *req;
MPIDI_OFI_REQUEST_CREATE(req, MPIR_REQUEST_KIND__RMA, vci);
if (1) {
MPIDI_CH4_REQUEST_FREE(req);
}
origin_context->req = req;

/* fill am_hdr */
struct get_hdr am_hdr;
am_hdr.win_id = MPIDIG_WIN(win, win_id);
am_hdr.origin_rank = win->comm_ptr->rank;
am_hdr.origin_context = origin_context;
am_hdr.data_sz = origin_context->data_sz;
am_hdr.target_offset = origin_context->target_offset;

MPIDIG_win_cmpl_cnts_incr(win, target_rank, NULL);

MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci));
mpi_errno = MPIDI_NM_am_send_hdr(target_rank, win->comm_ptr, MPIDI_OFI_GET_REQ,
&am_hdr, sizeof(am_hdr), vci, vci_target);
MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci));
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

struct target_mirror_copy {
MPIR_Win *win;
int origin_rank;
void *origin_context;
int vci_origin;
int vci_target;
MPIR_gpu_req async_req;
};

static int target_mirror_copy_poll(MPIX_Async_thing thing);

/* target side - AM callback */
int MPIDI_OFI_get_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
uint32_t attr, MPIR_Request ** req)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

struct get_hdr *msg_hdr = am_hdr;

MPIR_Win *win;
win = (MPIR_Win *) MPIDIU_map_lookup(MPIDI_global.win_map, msg_hdr->win_id);
MPIR_Assert(win);

void *mirror_buf = MPIDI_OFI_WIN(win).mirror_buf;
void *mirror_attr = &(MPIDI_OFI_WIN(win).mirror_attr);
void *base_buf = win->base;
void *base_attr = &(MPIDI_OFI_WIN(win).base_attr);

/* async localcopy */
MPIR_gpu_req async_req;
int engine_type = MPIDI_OFI_gpu_get_send_engine_type();
mpi_errno = MPIR_Ilocalcopy_gpu(base_buf, msg_hdr->data_sz, MPIR_BYTE_INTERNAL,
msg_hdr->target_offset, base_attr,
mirror_buf, msg_hdr->data_sz, MPIR_BYTE_INTERNAL,
msg_hdr->target_offset, mirror_attr,
engine_type, 1, &async_req);
MPIR_ERR_CHECK(mpi_errno);

/* add async things */
struct target_mirror_copy *p = MPL_malloc(sizeof(struct target_mirror_copy), MPL_MEM_OTHER);
p->win = win;
p->origin_rank = msg_hdr->origin_rank;
p->origin_context = msg_hdr->origin_context;
p->vci_origin = MPIDIG_AM_ATTR_SRC_VCI(attr);
p->vci_target = MPIDIG_AM_ATTR_DST_VCI(attr);
p->async_req = async_req;

mpi_errno = MPIR_Async_things_add(target_mirror_copy_poll, p, NULL);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

struct getack_hdr {
void *origin_context;
uint64_t rkey;
uint64_t remote_base;
};

/* target side - async callback */
static int target_mirror_copy_poll(MPIX_Async_thing thing)
{
struct target_mirror_copy *p = MPIR_Async_thing_get_state(thing);
int is_done;
MPIR_async_test(&p->async_req, &is_done);

if (is_done) {
/* send get_ack */
struct getack_hdr am_hdr;
am_hdr.origin_context = p->origin_context;
am_hdr.rkey = fi_mr_key(MPIDI_OFI_WIN(p->win).mr);
am_hdr.remote_base = (uintptr_t) MPIDI_OFI_WIN(p->win).mirror_buf;

int rc = MPIDI_NM_am_send_hdr(p->origin_rank, p->win->comm_ptr, MPIDI_OFI_GET_ACK,
&am_hdr, sizeof(am_hdr), p->vci_target, p->vci_origin);
MPIR_Assertp(rc == MPI_SUCCESS);

MPL_free(p);

return MPIX_ASYNC_DONE;
}

return MPIX_ASYNC_NOPROGRESS;
}

struct read_req {
char pad[MPIDI_REQUEST_HDR_SIZE];
struct fi_context context[MPIDI_OFI_CONTEXT_STRUCTS];
int event_id;
struct get_context *origin_context;
};

static int rdmaread_completion(void *context);

/* origin side - AM callback */
int MPIDI_OFI_getack_handler(void *am_hdr, void *data, MPI_Aint in_data_sz,
uint32_t attr, MPIR_Request ** req)
{
int mpi_errno = MPI_SUCCESS;

struct getack_hdr *msg_hdr = am_hdr;
struct get_context *origin_context = msg_hdr->origin_context;
MPIR_Win *win = origin_context->win;
int target_rank = origin_context->target_rank;
MPI_Aint target_offset = origin_context->target_offset;

MPIDI_OFI_rndvread_t *p = &MPIDI_OFI_AMREQ_READ(origin_context->req);
p->buf = origin_context->origin_addr;
p->count = origin_context->data_sz;
p->datatype = MPIR_BYTE_INTERNAL;

MPIR_GPU_query_pointer_attr(p->buf, &p->attr);
p->need_pack = MPL_gpu_attr_is_dev(&p->attr);

p->data_sz = p->remote_data_sz = origin_context->data_sz;
p->vci_local = MPIDI_WIN(win, am_vci);
p->vci_remote = MPIDI_WIN_TARGET_VCI(win, target_rank);
p->av = MPIDIU_win_rank_to_av(win, target_rank, MPIDI_WIN(win, winattr));

p->num_nics = 1;
if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) {
p->u.recv.remote_base = msg_hdr->remote_base + target_offset;
} else {
p->u.recv.remote_base = target_offset;
}
p->u.recv.rkeys = &p->u.recv.rkey0;
p->u.recv.rkey0 = msg_hdr->rkey;
p->u.recv.cmpl_cb = rdmaread_completion;
p->u.recv.context = origin_context;

mpi_errno = MPIR_Async_things_add(MPIDI_OFI_rdmaread_poll, origin_context->req, NULL);

return mpi_errno;
}

static int rdmaread_completion(void *context)
{
struct get_context *origin_context = context;

MPIR_Win *win = origin_context->win;
int target_rank = origin_context->target_rank;

MPIDIG_win_cmpl_cnts_decr(win, target_rank);

MPIDI_Request_complete_fast(origin_context->req);
MPL_free(origin_context);

return MPI_SUCCESS;
}
39 changes: 35 additions & 4 deletions src/mpid/ch4/netmod/ofi/ofi_rma.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,43 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get(void *origin_addr,
MPIDI_winattr_t winattr)
{
int mpi_errno = MPI_SUCCESS;

MPIR_FUNC_ENTER;

/* TODO: move pre-checks to ch4_rma */
MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win);

/* check early exit */
if (target_count == 0)
goto fn_exit;

if (target_rank == win->comm_ptr->rank) {
MPI_Aint offset;
offset = win->disp_unit * target_disp;
mpi_errno = MPIR_Localcopy((char *) win->base + offset, target_count, target_datatype,
origin_addr, origin_count, origin_datatype);
MPIR_ERR_CHECK(mpi_errno);
goto fn_exit;
}

if (!MPIDI_OFI_ENABLE_RMA || !(winattr & MPIDI_WINATTR_NM_REACHABLE) ||
!MPIDI_OFI_gpu_rma_enabled(origin_addr)) {
MPIDI_OFI_register_am_bufs();
mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank,
target_disp, target_count, target_datatype, win);
MPI_Aint data_sz;
MPIDI_Datatype_check_size(target_datatype, target_count, data_sz);
bool good_size = (data_sz >= MPIDI_NM_am_eager_limit());
int origin_is_contig, target_is_contig;
MPIR_Datatype_is_contig(origin_datatype, &origin_is_contig);
MPIR_Datatype_is_contig(target_datatype, &target_is_contig);
/* for now, only optimize for large contig data */
if (origin_is_contig && target_is_contig && good_size && MPIR_CVAR_OFI_ENABLE_WIN_MIRROR) {
/* use mirror_buf optimization */
mpi_errno = MPIDI_OFI_mirror_get(origin_addr, origin_count, origin_datatype,
target_rank,
target_disp, target_count, target_datatype, win);
} else {
MPIDI_OFI_register_am_bufs();
mpi_errno = MPIDIG_mpi_get(origin_addr, origin_count, origin_datatype, target_rank,
target_disp, target_count, target_datatype, win);
}
goto fn_exit;
}

Expand All @@ -546,6 +575,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_get(void *origin_addr,
fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_rput(const void *origin_addr,
Expand Down
Loading