@@ -98,6 +98,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
98
98
};
99
99
#endif
100
100
101
+ unsigned
102
+ mca_spml_ucx_mem_map_flags_symmetric_rkey (struct mca_spml_ucx * spml_ucx )
103
+ {
104
+ #if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
105
+ if (spml_ucx -> symmetric_rkey_max_count > 0 ) {
106
+ return UCP_MEM_MAP_SYMMETRIC_RKEY ;
107
+ }
108
+ #endif
109
+
110
+ return 0 ;
111
+ }
112
+
113
+ void mca_spml_ucx_rkey_store_init (mca_spml_ucx_rkey_store_t * store )
114
+ {
115
+ store -> array = NULL ;
116
+ store -> count = 0 ;
117
+ store -> size = 0 ;
118
+ }
119
+
120
+ void mca_spml_ucx_rkey_store_cleanup (mca_spml_ucx_rkey_store_t * store )
121
+ {
122
+ int i ;
123
+
124
+ for (i = 0 ; i < store -> count ; i ++ ) {
125
+ if (store -> array [i ].refcnt != 0 ) {
126
+ SPML_UCX_ERROR ("rkey store destroy: %d/%d has refcnt %d > 0" ,
127
+ i , store -> count , store -> array [i ].refcnt );
128
+ }
129
+
130
+ ucp_rkey_destroy (store -> array [i ].rkey );
131
+ }
132
+
133
+ free (store -> array );
134
+ }
135
+
136
+ /**
137
+ * Find position in sorted array for existing or future entry
138
+ *
139
+ * @param[in] store Store of the entries
140
+ * @param[in] worker Common worker for rkeys used
141
+ * @param[in] rkey Remote key to search for
142
+ * @param[out] index Index of entry
143
+ *
144
+ * @return
145
+ * OSHMEM_ERR_NOT_FOUND: index contains the position where future element
146
+ * should be inserted to keep array sorted
147
+ * OSHMEM_SUCCESS : index contains the position of the element
148
+ * Other error : index is not valid
149
+ */
150
+ static int mca_spml_ucx_rkey_store_find (const mca_spml_ucx_rkey_store_t * store ,
151
+ const ucp_worker_h worker ,
152
+ const ucp_rkey_h rkey ,
153
+ int * index )
154
+ {
155
+ #if HAVE_DECL_UCP_RKEY_COMPARE
156
+ ucp_rkey_compare_params_t params ;
157
+ int i , result , m , end ;
158
+ ucs_status_t status ;
159
+
160
+ for (i = 0 , end = store -> count ; i < end ;) {
161
+ m = (i + end ) / 2 ;
162
+
163
+ params .field_mask = 0 ;
164
+ status = ucp_rkey_compare (worker , store -> array [m ].rkey ,
165
+ rkey , & params , & result );
166
+ if (status != UCS_OK ) {
167
+ return OSHMEM_ERROR ;
168
+ } else if (result == 0 ) {
169
+ * index = m ;
170
+ return OSHMEM_SUCCESS ;
171
+ } else if (result > 0 ) {
172
+ end = m ;
173
+ } else {
174
+ i = m + 1 ;
175
+ }
176
+ }
177
+
178
+ * index = i ;
179
+ return OSHMEM_ERR_NOT_FOUND ;
180
+ #else
181
+ return OSHMEM_ERROR ;
182
+ #endif
183
+ }
184
+
185
+ static void mca_spml_ucx_rkey_store_insert (mca_spml_ucx_rkey_store_t * store ,
186
+ int i , ucp_rkey_h rkey )
187
+ {
188
+ int size ;
189
+ mca_spml_ucx_rkey_t * tmp ;
190
+
191
+ if (store -> count >= mca_spml_ucx .symmetric_rkey_max_count ) {
192
+ return ;
193
+ }
194
+
195
+ if (store -> count >= store -> size ) {
196
+ size = sshmem_ucx_min (sshmem_ucx_max (store -> size , 8 ) * 2 ,
197
+ mca_spml_ucx .symmetric_rkey_max_count );
198
+ tmp = realloc (store -> array , size * sizeof (* store -> array ));
199
+ if (tmp == NULL ) {
200
+ return ;
201
+ }
202
+
203
+ store -> array = tmp ;
204
+ store -> size = size ;
205
+ }
206
+
207
+ memmove (& store -> array [i + 1 ], & store -> array [i ],
208
+ (store -> count - i ) * sizeof (* store -> array ));
209
+ store -> array [i ].rkey = rkey ;
210
+ store -> array [i ].refcnt = 1 ;
211
+ store -> count ++ ;
212
+ return ;
213
+ }
214
+
215
+ /* Takes ownership of input ucp remote key */
216
+ static ucp_rkey_h mca_spml_ucx_rkey_store_get (mca_spml_ucx_rkey_store_t * store ,
217
+ ucp_worker_h worker ,
218
+ ucp_rkey_h rkey )
219
+ {
220
+ int ret , i ;
221
+
222
+ if (mca_spml_ucx .symmetric_rkey_max_count == 0 ) {
223
+ return rkey ;
224
+ }
225
+
226
+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
227
+ if (ret == OSHMEM_SUCCESS ) {
228
+ ucp_rkey_destroy (rkey );
229
+ store -> array [i ].refcnt ++ ;
230
+ return store -> array [i ].rkey ;
231
+ }
232
+
233
+ if (ret == OSHMEM_ERR_NOT_FOUND ) {
234
+ mca_spml_ucx_rkey_store_insert (store , i , rkey );
235
+ }
236
+
237
+ return rkey ;
238
+ }
239
+
240
+ static void mca_spml_ucx_rkey_store_put (mca_spml_ucx_rkey_store_t * store ,
241
+ ucp_worker_h worker ,
242
+ ucp_rkey_h rkey )
243
+ {
244
+ mca_spml_ucx_rkey_t * entry ;
245
+ int ret , i ;
246
+
247
+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
248
+ if (ret != OSHMEM_SUCCESS ) {
249
+ goto out ;
250
+ }
251
+
252
+ entry = & store -> array [i ];
253
+ assert (entry -> rkey == rkey );
254
+ if (-- entry -> refcnt > 0 ) {
255
+ return ;
256
+ }
257
+
258
+ memmove (& store -> array [i ], & store -> array [i + 1 ],
259
+ (store -> count - (i + 1 )) * sizeof (* store -> array ));
260
+ store -> count -- ;
261
+
262
+ out :
263
+ ucp_rkey_destroy (rkey );
264
+ }
265
+
101
266
int mca_spml_ucx_enable (bool enable )
102
267
{
103
268
SPML_UCX_VERBOSE (50 , "*** ucx ENABLED ****" );
@@ -212,6 +377,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
212
377
{
213
378
int rc ;
214
379
ucs_status_t err ;
380
+ ucp_rkey_h rkey ;
215
381
216
382
rc = mca_spml_ucx_ctx_mkey_new (ucx_ctx , pe , segno , ucx_mkey );
217
383
if (OSHMEM_SUCCESS != rc ) {
@@ -220,11 +386,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
220
386
}
221
387
222
388
if (mkey -> u .data ) {
223
- err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & (( * ucx_mkey ) -> rkey ) );
389
+ err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & rkey );
224
390
if (UCS_OK != err ) {
225
391
SPML_UCX_ERROR ("failed to unpack rkey: %s" , ucs_status_string (err ));
226
392
return OSHMEM_ERROR ;
227
393
}
394
+
395
+ if (!oshmem_proc_on_local_node (pe )) {
396
+ rkey = mca_spml_ucx_rkey_store_get (& ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [0 ], rkey );
397
+ }
398
+
399
+ (* ucx_mkey )-> rkey = rkey ;
400
+
228
401
rc = mca_spml_ucx_ctx_mkey_cache (ucx_ctx , mkey , segno , pe );
229
402
if (OSHMEM_SUCCESS != rc ) {
230
403
SPML_UCX_ERROR ("mca_spml_ucx_ctx_mkey_cache failed" );
@@ -239,7 +412,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
239
412
ucp_peer_t * ucp_peer ;
240
413
int rc ;
241
414
ucp_peer = & (ucx_ctx -> ucp_peers [pe ]);
242
- ucp_rkey_destroy ( ucx_mkey -> rkey );
415
+ mca_spml_ucx_rkey_store_put ( & ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [ 0 ], ucx_mkey -> rkey );
243
416
ucx_mkey -> rkey = NULL ;
244
417
rc = mca_spml_ucx_peer_mkey_cache_del (ucp_peer , segno );
245
418
if (OSHMEM_SUCCESS != rc ){
@@ -697,7 +870,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
697
870
UCP_MEM_MAP_PARAM_FIELD_FLAGS ;
698
871
mem_map_params .address = addr ;
699
872
mem_map_params .length = size ;
700
- mem_map_params .flags = flags ;
873
+ mem_map_params .flags = flags |
874
+ mca_spml_ucx_mem_map_flags_symmetric_rkey (& mca_spml_ucx );
701
875
702
876
status = ucp_mem_map (mca_spml_ucx .ucp_context , & mem_map_params , & mem_h );
703
877
if (UCS_OK != status ) {
@@ -887,6 +1061,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
887
1061
}
888
1062
}
889
1063
1064
+ mca_spml_ucx_rkey_store_init (& ucx_ctx -> rkey_store );
1065
+
890
1066
* ucx_ctx_p = ucx_ctx ;
891
1067
892
1068
return OSHMEM_SUCCESS ;
0 commit comments