Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

MXExecutorSimpleBindEx for backward compatability #7227

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1272,9 +1272,6 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
Expand All @@ -1289,6 +1286,41 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);

MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
/*!
* \brief set a call back to notify the completion of operation
*/
Expand Down
68 changes: 34 additions & 34 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,40 +1363,40 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

try:
check_call(_LIB.MXExecutorSimpleBind(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_array(ctypes.c_char_p, provided_arg_shape_names),
c_array(mx_uint, provided_arg_shape_data),
c_array(mx_uint, provided_arg_shape_idx),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
check_call(_LIB.MXExecutorSimpleBindEx(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_array(ctypes.c_char_p, provided_arg_shape_names),
c_array(mx_uint, provided_arg_shape_data),
c_array(mx_uint, provided_arg_shape_idx),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
except MXNetError as e:
error_msg = "simple_bind error. Arguments:\n"
for k, v in kwargs.items():
Expand Down
116 changes: 110 additions & 6 deletions src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ int MXExecutorBindEX(SymbolHandle symbol_handle,
* \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes
* \param provided_arg_dtype_names argument name list of provided dtypes
* \param provided_arg_dtypes data of provided dtypes
* \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types
* \param provided_arg_stype_names argument name list of provided storage types
* \param provided_arg_stypes data of provided storage types
* \param num_shared_arg_names number of parameter names passed from _bind_ith_exec
* \param shared_arg_name_list parameter name list passed from _bind_ith_exec
* \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec
Expand Down Expand Up @@ -208,9 +205,6 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
Expand All @@ -225,6 +219,116 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
const mx_uint num_provided_arg_stypes = 0;
const char** provided_arg_stype_names = nullptr;
const int* provided_arg_stypes = nullptr;
return MXExecutorSimpleBindEx(symbol_handle,
dev_type,
dev_id,
num_g2c_keys,
g2c_keys,
g2c_dev_types,
g2c_dev_ids,
provided_grad_req_list_len,
provided_grad_req_names,
provided_grad_req_types,
num_provided_arg_shapes,
provided_arg_shape_names,
provided_arg_shape_data,
provided_arg_shape_idx,
num_provided_arg_dtypes,
provided_arg_dtype_names,
provided_arg_dtypes,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stypes,
num_shared_arg_names,
shared_arg_name_list,
shared_buffer_len,
shared_buffer_name_list,
shared_buffer_handle_list,
updated_shared_buffer_name_list,
updated_shared_buffer_handle_list,
num_in_args,
in_args,
arg_grads,
num_aux_states,
aux_states,
shared_exec_handle,
out);
}

/*!
* \brief
* \param symbol_handle symbol handle
* \param dev_type default device type
* \param dev_id default device id
* \param num_g2c_keys number of group2ctx keys
* \param g2c_keys key list of group2ctx
* \param g2c_dev_types device type list of group2ctx
* \param g2c_dev_ids id list of group2ctx
* \param provided_grad_req_list_len grad_req length provided by users in front-end
* \param provided_grad_req_names grad_req names provided by users in front-end
* \param provided_grad_req_types req types provided by users in front-end
* \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes
* \param provided_arg_shape_names name list of provided shapes
* \param provided_arg_shape_data provided shape data
* \param provided_arg_shape_idx provided shape data index
* \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes
* \param provided_arg_dtype_names argument name list of provided dtypes
* \param provided_arg_dtypes data of provided dtypes
* \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types
* \param provided_arg_stype_names argument name list of provided storage types
* \param provided_arg_stypes data of provided storage types
* \param num_shared_arg_names number of parameter names passed from _bind_ith_exec
* \param shared_arg_name_list parameter name list passed from _bind_ith_exec
* \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec
* \param shared_buffer_name_list shared data array names passed from _bind_ith_exec
* \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec
* \param updated_shared_buffer_name_list updated shared data array names after binding
* \param updated_shared_buffer_handle_list updated shared data arrays after binding
* \param num_in_args number of input arguments of this sym
* \param in_args list_arguments associated with the current executor
* \param arg_grads list of gradients of in_args associated with the current executor
* \param num_aux_states number of aux states of this sym
* \param aux_states list_auxiliary_states associated with the current executor
* \param shared_exec_handle shared excutor handle passed from _bind_ith_exec
* \param out the handle of the executor to be created
*/
int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** provided_arg_shape_names,
const mx_uint* provided_arg_shape_data,
const mx_uint* provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
mx_uint* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);
Expand Down