Skip to content

Commit 386b88d

Browse files
authored
Merge pull request #11979 from tvegas1/smkey_store_v4.1.x
oshmem: Add symmetric remote key handling code
2 parents b2c910a + 937abaf commit 386b88d

File tree

6 files changed

+243
-31
lines changed

6 files changed

+243
-31
lines changed

config/ompi_check_ucx.m4

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
140140
UCP_PARAM_FIELD_ESTIMATED_NUM_PPN,
141141
UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK,
142142
UCP_OP_ATTR_FLAG_MULTI_SEND,
143-
UCS_MEMORY_TYPE_RDMA],
143+
UCS_MEMORY_TYPE_RDMA,
144+
UCP_MEM_MAP_SYMMETRIC_RKEY],
144145
[], [],
145146
[#include <ucp/api/ucp.h>])
146147
AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS],
@@ -153,7 +154,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
153154
[#include <ucp/api/ucp.h>])
154155
AC_CHECK_DECLS([ucp_tag_send_nbx,
155156
ucp_tag_send_sync_nbx,
156-
ucp_tag_recv_nbx],
157+
ucp_tag_recv_nbx,
158+
ucp_rkey_compare],
157159
[], [],
158160
[#include <ucp/api/ucp.h>])
159161
AC_CHECK_TYPES([ucp_request_param_t],

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 179 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
9898
};
9999
#endif
100100

101+
unsigned
102+
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
103+
{
104+
#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
105+
if (spml_ucx->symmetric_rkey_max_count > 0) {
106+
return UCP_MEM_MAP_SYMMETRIC_RKEY;
107+
}
108+
#endif
109+
110+
return 0;
111+
}
112+
113+
void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
114+
{
115+
store->array = NULL;
116+
store->count = 0;
117+
store->size = 0;
118+
}
119+
120+
void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
121+
{
122+
int i;
123+
124+
for (i = 0; i < store->count; i++) {
125+
if (store->array[i].refcnt != 0) {
126+
SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
127+
i, store->count, store->array[i].refcnt);
128+
}
129+
130+
ucp_rkey_destroy(store->array[i].rkey);
131+
}
132+
133+
free(store->array);
134+
}
135+
136+
/**
137+
* Find position in sorted array for existing or future entry
138+
*
139+
* @param[in] store Store of the entries
140+
* @param[in] worker Common worker for rkeys used
141+
* @param[in] rkey Remote key to search for
142+
* @param[out] index Index of entry
143+
*
144+
* @return
145+
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
146+
* should be inserted to keep array sorted
147+
* OSHMEM_SUCCESS : index contains the position of the element
148+
* Other error : index is not valid
149+
*/
150+
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
151+
const ucp_worker_h worker,
152+
const ucp_rkey_h rkey,
153+
int *index)
154+
{
155+
#if HAVE_DECL_UCP_RKEY_COMPARE
156+
ucp_rkey_compare_params_t params;
157+
int i, result, m, end;
158+
ucs_status_t status;
159+
160+
for (i = 0, end = store->count; i < end;) {
161+
m = (i + end) / 2;
162+
163+
params.field_mask = 0;
164+
status = ucp_rkey_compare(worker, store->array[m].rkey,
165+
rkey, &params, &result);
166+
if (status != UCS_OK) {
167+
return OSHMEM_ERROR;
168+
} else if (result == 0) {
169+
*index = m;
170+
return OSHMEM_SUCCESS;
171+
} else if (result > 0) {
172+
end = m;
173+
} else {
174+
i = m + 1;
175+
}
176+
}
177+
178+
*index = i;
179+
return OSHMEM_ERR_NOT_FOUND;
180+
#else
181+
return OSHMEM_ERROR;
182+
#endif
183+
}
184+
185+
static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
186+
int i, ucp_rkey_h rkey)
187+
{
188+
int size;
189+
mca_spml_ucx_rkey_t *tmp;
190+
191+
if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
192+
return;
193+
}
194+
195+
if (store->count >= store->size) {
196+
size = sshmem_ucx_min(sshmem_ucx_max(store->size, 8) * 2,
197+
mca_spml_ucx.symmetric_rkey_max_count);
198+
tmp = realloc(store->array, size * sizeof(*store->array));
199+
if (tmp == NULL) {
200+
return;
201+
}
202+
203+
store->array = tmp;
204+
store->size = size;
205+
}
206+
207+
memmove(&store->array[i + 1], &store->array[i],
208+
(store->count - i) * sizeof(*store->array));
209+
store->array[i].rkey = rkey;
210+
store->array[i].refcnt = 1;
211+
store->count++;
212+
return;
213+
}
214+
215+
/* Takes ownership of input ucp remote key */
216+
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
217+
ucp_worker_h worker,
218+
ucp_rkey_h rkey)
219+
{
220+
int ret, i;
221+
222+
if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
223+
return rkey;
224+
}
225+
226+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
227+
if (ret == OSHMEM_SUCCESS) {
228+
ucp_rkey_destroy(rkey);
229+
store->array[i].refcnt++;
230+
return store->array[i].rkey;
231+
}
232+
233+
if (ret == OSHMEM_ERR_NOT_FOUND) {
234+
mca_spml_ucx_rkey_store_insert(store, i, rkey);
235+
}
236+
237+
return rkey;
238+
}
239+
240+
static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
241+
ucp_worker_h worker,
242+
ucp_rkey_h rkey)
243+
{
244+
mca_spml_ucx_rkey_t *entry;
245+
int ret, i;
246+
247+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
248+
if (ret != OSHMEM_SUCCESS) {
249+
goto out;
250+
}
251+
252+
entry = &store->array[i];
253+
assert(entry->rkey == rkey);
254+
if (--entry->refcnt > 0) {
255+
return;
256+
}
257+
258+
memmove(&store->array[i], &store->array[i + 1],
259+
(store->count - (i + 1)) * sizeof(*store->array));
260+
store->count--;
261+
262+
out:
263+
ucp_rkey_destroy(rkey);
264+
}
265+
101266
int mca_spml_ucx_enable(bool enable)
102267
{
103268
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
@@ -212,6 +377,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
212377
{
213378
int rc;
214379
ucs_status_t err;
380+
ucp_rkey_h rkey;
215381

216382
rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey);
217383
if (OSHMEM_SUCCESS != rc) {
@@ -220,11 +386,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
220386
}
221387

222388
if (mkey->u.data) {
223-
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey));
389+
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey);
224390
if (UCS_OK != err) {
225391
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
226392
return OSHMEM_ERROR;
227393
}
394+
395+
if (!oshmem_proc_on_local_node(pe)) {
396+
rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey);
397+
}
398+
399+
(*ucx_mkey)->rkey = rkey;
400+
228401
rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe);
229402
if (OSHMEM_SUCCESS != rc) {
230403
SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed");
@@ -239,7 +412,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
239412
ucp_peer_t *ucp_peer;
240413
int rc;
241414
ucp_peer = &(ucx_ctx->ucp_peers[pe]);
242-
ucp_rkey_destroy(ucx_mkey->rkey);
415+
mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey);
243416
ucx_mkey->rkey = NULL;
244417
rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno);
245418
if(OSHMEM_SUCCESS != rc){
@@ -697,7 +870,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
697870
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
698871
mem_map_params.address = addr;
699872
mem_map_params.length = size;
700-
mem_map_params.flags = flags;
873+
mem_map_params.flags = flags |
874+
mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx);
701875

702876
status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h);
703877
if (UCS_OK != status) {
@@ -887,6 +1061,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
8871061
}
8881062
}
8891063

1064+
mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);
1065+
8901066
*ucx_ctx_p = ucx_ctx;
8911067

8921068
return OSHMEM_SUCCESS;

oshmem/mca/spml/ucx/spml_ucx.h

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,31 @@ struct ucp_peer {
7777
size_t mkeys_cnt;
7878
};
7979
typedef struct ucp_peer ucp_peer_t;
80-
80+
81+
/* An rkey_store entry */
82+
typedef struct mca_spml_ucx_rkey {
83+
ucp_rkey_h rkey;
84+
int refcnt;
85+
} mca_spml_ucx_rkey_t;
86+
87+
typedef struct mca_spml_ucx_rkey_store {
88+
mca_spml_ucx_rkey_t *array;
89+
int size;
90+
int count;
91+
} mca_spml_ucx_rkey_store_t;
92+
8193
struct mca_spml_ucx_ctx {
82-
ucp_worker_h *ucp_worker;
83-
ucp_peer_t *ucp_peers;
84-
long options;
85-
opal_bitmap_t put_op_bitmap;
86-
unsigned long nb_progress_cnt;
87-
unsigned int ucp_workers;
88-
int *put_proc_indexes;
89-
unsigned put_proc_count;
90-
bool synchronized_quiet;
91-
int strong_sync;
94+
ucp_worker_h *ucp_worker;
95+
ucp_peer_t *ucp_peers;
96+
long options;
97+
opal_bitmap_t put_op_bitmap;
98+
unsigned long nb_progress_cnt;
99+
unsigned int ucp_workers;
100+
int *put_proc_indexes;
101+
unsigned put_proc_count;
102+
bool synchronized_quiet;
103+
int strong_sync;
104+
mca_spml_ucx_rkey_store_t rkey_store;
92105
};
93106
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;
94107

@@ -129,6 +142,7 @@ struct mca_spml_ucx {
129142
unsigned long nb_ucp_worker_progress;
130143
unsigned int ucp_workers;
131144
unsigned int ucp_worker_cnt;
145+
int symmetric_rkey_max_count;
132146
};
133147
typedef struct mca_spml_ucx mca_spml_ucx_t;
134148

@@ -217,6 +231,12 @@ int mca_spml_ucx_peer_mkey_cache_del(ucp_peer_t *ucp_peer, int segno);
217231
void mca_spml_ucx_peer_mkey_cache_release(ucp_peer_t *ucp_peer);
218232
void mca_spml_ucx_peer_mkey_cache_init(mca_spml_ucx_ctx_t *ucx_ctx, int pe);
219233

234+
extern unsigned
235+
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);
236+
237+
extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
238+
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);
239+
220240
static inline int
221241
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)
222242
{

oshmem/mca/spml/ucx/spml_ucx_component.c

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ static int mca_spml_ucx_component_register(void)
153153
"Enable asynchronous progress thread",
154154
&mca_spml_ucx.async_progress);
155155

156+
mca_spml_ucx_param_register_int("symmetric_rkey_max_count", 0,
157+
"Size of the symmetric key store. Non-zero to enable, typical use 5000",
158+
&mca_spml_ucx.symmetric_rkey_max_count);
159+
156160
mca_spml_ucx_param_register_int("async_tick_usec", 3000,
157161
"Asynchronous progress tick granularity (in usec)",
158162
&mca_spml_ucx.async_tick);
@@ -332,6 +336,8 @@ static int spml_ucx_init(void)
332336
mca_spml_ucx_ctx_default.ucp_workers++;
333337
}
334338

339+
mca_spml_ucx_rkey_store_init(&mca_spml_ucx_ctx_default.rkey_store);
340+
335341
wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
336342
err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr);
337343

@@ -432,10 +438,25 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx)
432438
free(ctx->ucp_peers);
433439
}
434440

441+
static void mca_spml_ucx_ctx_fini(mca_spml_ucx_ctx_t *ctx)
442+
{
443+
unsigned int i;
444+
445+
mca_spml_ucx_rkey_store_cleanup(&ctx->rkey_store);
446+
for (i = 0; i < ctx->ucp_workers; i++) {
447+
ucp_worker_destroy(ctx->ucp_worker[i]);
448+
}
449+
free(ctx->ucp_worker);
450+
if (ctx != &mca_spml_ucx_ctx_default) {
451+
free(ctx);
452+
}
453+
}
454+
435455
static int mca_spml_ucx_component_fini(void)
436456
{
437457
int fenced = 0, i;
438458
int ret = OSHMEM_SUCCESS;
459+
mca_spml_ucx_ctx_t *ctx;
439460

440461
opal_progress_unregister(spml_ucx_default_progress);
441462
if (mca_spml_ucx.active_array.ctxs_count) {
@@ -488,36 +509,26 @@ static int mca_spml_ucx_component_fini(void)
488509
}
489510
}
490511

491-
/* delete all workers */
492512
for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) {
493-
ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]);
494-
free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker);
495-
free(mca_spml_ucx.active_array.ctxs[i]);
513+
mca_spml_ucx_ctx_fini(mca_spml_ucx.active_array.ctxs[i]);
496514
}
497515

498516
for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) {
499-
ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]);
500-
free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker);
501-
free(mca_spml_ucx.idle_array.ctxs[i]);
517+
mca_spml_ucx_ctx_fini(mca_spml_ucx.idle_array.ctxs[i]);
502518
}
503519

504520
if (mca_spml_ucx_ctx_default.ucp_worker) {
505-
for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) {
506-
ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]);
507-
}
508-
free(mca_spml_ucx_ctx_default.ucp_worker);
521+
mca_spml_ucx_ctx_fini(&mca_spml_ucx_ctx_default);
509522
}
510523

511524
if (mca_spml_ucx.aux_ctx != NULL) {
512-
ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]);
513-
free(mca_spml_ucx.aux_ctx->ucp_worker);
525+
mca_spml_ucx_ctx_fini(mca_spml_ucx.aux_ctx);
514526
}
515527

516528
mca_spml_ucx.enabled = false; /* not anymore */
517529

518530
free(mca_spml_ucx.active_array.ctxs);
519531
free(mca_spml_ucx.idle_array.ctxs);
520-
free(mca_spml_ucx.aux_ctx);
521532

522533
SHMEM_MUTEX_DESTROY(mca_spml_ucx.internal_mutex);
523534
pthread_mutex_destroy(&mca_spml_ucx.ctx_create_mutex);

0 commit comments

Comments
 (0)