This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
dynamic custom operator support #15921
Merged
Changes from 1 commit (of 139 commits in this pull request).

Commits:
5030a65
fixed example to use absolute path
23a226a
added example for custom ops, added support for custom op registration
67c22c0
added fcompute registration for loaded operators
915c1d5
changed dynamic ops to be contrib
f568e3d
added num in/out
8e12588
removed contrib op registration
1e27a47
added support for infer shape, updated example to call operator
9aecf86
fixed whitespace
02deacf
fixed whitespace
cf9350d
fixed whitespace
ada3895
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
38e77a5
added temporary support for operator multi-registration
7b8f6a2
insanity checked
5b817bd
update docblocks
rondogency 3bccfbe
small format fix
rondogency a8c19c8
fix unittest with correct library
rondogency 2f34471
implement InferType
rondogency 3502aa9
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
52e687b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
5438a35
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
592249a
initial support for resource manager, temp space
3186d60
fixed formatting
e8b413b
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
bf549b4
changed decltype to typedef
439ee20
fixed whitespace
bba25db
Added windows declaration types, change APIs to return MXReturnValue …
a681f61
added library version number, API to get, and check to validate
711f9a3
Changed CMakeLists to build lib_ops instead of lib_api, updated lib_a…
172129f
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
5af1736
add prototype of subgraph op
rondogency 33d9cd7
implement FMutateInput as optional attribute
rondogency 4576570
fix sanity check
rondogency 6f3e3d9
replace fcompute to fcomputeEx and implement simple finferstoragetype
rondogency 9587483
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
34efb2b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
ff9a868
changed fcompute to forward
0be218b
initial commit with fgradient support
570a059
enabled gradient registration
4b01932
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
e4be175
fixed whitespace
8cfcc85
fixed example to use absolute path
9884ec6
added example for custom ops, added support for custom op registration
8e21600
added fcompute registration for loaded operators
794e30b
changed dynamic ops to be contrib
8fbf664
added num in/out
e7c6e8f
removed contrib op registration
6047378
added support for infer shape, updated example to call operator
d1587ab
fixed whitespace
0ee56c9
fixed whitespace
adc9770
fixed whitespace
5c06d47
added temporary support for operator multi-registration
9136839
insanity checked
ffe7623
update docblocks
rondogency 435e01e
small format fix
rondogency 0de79a9
fix unittest with correct library
rondogency 0d6f7b0
implement InferType
rondogency 18b028e
initial support for resource manager, temp space
a4690b4
fixed formatting
c901828
changed decltype to typedef
5ddb919
fixed whitespace
7b4c4e6
Added windows declaration types, change APIs to return MXReturnValue …
18117ec
added library version number, API to get, and check to validate
ee65419
Changed CMakeLists to build lib_ops instead of lib_api, updated lib_a…
c66438c
add prototype of subgraph op
rondogency 698a0b6
implement FMutateInput as optional attribute
rondogency bd55612
fix sanity check
rondogency 35ff973
replace fcompute to fcomputeEx and implement simple finferstoragetype
rondogency f243e2f
changed fcompute to forward
efbb858
initial commit with fgradient support
0032143
enabled gradient registration
14ef3a7
fixed whitespace
eec71d6
prototype of createopstate and fstatefulcompute
rondogency abcb8cb
make custom state op interface work
rondogency 9cf0455
subgraph forward
rondogency 82f1bff
refactor stateful forward and add op resource
rondogency f7ff481
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
ba563d2
wip gemm backward
a9b7215
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
7bf4f7a
stateful backward and subgraph test
rondogency 8aec7ac
implement gemm and state gemm, refactor test files
rondogency 39e3d6b
add body to pure virtual destructor
rondogency b3ba028
subgraph passing from python to custom lib
rondogency c9d8498
Merge branch 'master' into dynamic_op
rondogency 1686273
rm lib_api c++11 dep, rm warpctc, add rm flag
rondogency 7009ad4
fix conflict
rondogency 4b73179
subgraph json parsing utility
rondogency dca521e
add data size and fix unsigned warnings
rondogency baed04e
use c++ struct and fix cpplint
rondogency aedcf91
refactor op registry
rondogency 75102a3
fix line length and win array of ci; condense lines
rondogency 9c29deb
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency c5a3ed6
add mxnet_extension dir
rondogency 44683f1
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
d1b6c8e
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
44affc7
fixed extension to be dll for windows
ef1d4cf
updated examples to use the same format as the example in the top-lev…
24d8cc3
removed destructor for CustomStatefulOp
279a989
fix error in gemm test and clear up subgraph test
rondogency 5db9e97
merge with dynamic_op
rondogency 75b1169
lib path fix
rondogency 79c0e3a
add unittest for custom op
rondogency 11d3344
update Makefile revolve merge
rondogency de157a8
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
28450b5
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
9504b33
fix test and rename folder
rondogency cf27d57
fix makefile rename
rondogency 7f456d4
fix cmake rename
rondogency e50819b
add explicit cpu context
rondogency 5984f3a
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency bd2c3a0
wkcn feedback: change mxtensor func name. use c++11 flag
rondogency b07e46b
add operator keyward test and refine info print
rondogency 2466d67
using typedef in forward
rondogency e041400
small refine of docblock
rondogency f16942c
change names
rondogency 50a6b64
add separate stateful compute and pass state_op ptr
rondogency adb0415
user example using opresource alloc
rondogency 6148ef8
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency 6d9ac54
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency 7c256cd
Merge remote-tracking branch 'upstream/master' into dynamic_op
rondogency 6e824fb
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
40c471b
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
dfb5946
added DLTensor into MXTensor
5146fd5
fixed whitespace
1b9fee2
added error check when DLTensor does not support MXNet data type
5761891
changed to throw runtime exception
ef840b4
changed include to stdexcept
bba61b3
retrigger CI
wkcn 53d18ec
empty commit
e0c778c
Merge branch 'dynamic_op' of https://github.com/samskalicky/incubator…
141328f
empty commit
deacae2
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
2b2c6a4
Merge branch 'master' into dynamic_op
szha ed8ac16
remove merge conflict
rondogency 56b0e28
add setdltensor for easy use and add docs
rondogency 1bd166e
Merge branch 'master' into dynamic_op
wkcn 50c8aea
CI
wkcn 34a9ee9
re-trigger CI
wkcn 9910c39
ci
wkcn 5fd4314
ci
wkcn
Commit 50a6b6425707349726ede08944594c8ebd2b2151: add separate stateful compute and pass state_op ptr
@@ -138,13 +138,16 @@ int MXLoadLib(const char *path) {
   opCallCreateOpState_t callCreateOpState =
     get_func<opCallCreateOpState_t>(lib, const_cast<char*>(MXLIB_OPCALLCREATEOPSTATE_STR));
 
+  opCallFStatefulComp_t callFStatefulComp =
+    get_func<opCallFStatefulComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFSTATEFULCOMP_STR));
+
   // get number of operators registered in the library
   opRegSize_t opRegSize = get_func<opRegSize_t>(lib, const_cast<char*>(MXLIB_OPREGSIZE_STR));
   int numOps = opRegSize();
   LOG(INFO) << "Found " << numOps << " operators in library";
 
   /*
    * The library has custom operators implementation
    * Get all custom operators implementation from custom library
    * loop and register each operator in the library to NNVM
    */
   opRegGet_t opRegGet = get_func<opRegGet_t>(lib, const_cast<char*>(MXLIB_OPREGGET_STR));
@@ -368,11 +371,11 @@ int MXLoadLib(const char *path) {
 
   // lambda function to convert from external fcompute to internal MXNet types
   auto fcomp_lambda = [=](fcomp_t fcomp_fp,
-                        const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const std::vector<NDArray>& inputs,
-                        const std::vector<OpReqType>& req,
-                        const std::vector<NDArray>& outputs) {
+                          const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
     // convert attributes to vector of char*
     std::vector<const char*> attr_keys, attr_vals;
     for (auto kv : attrs.dict) {

(This hunk is a whitespace-only change that re-aligns the parameter list of fcomp_lambda.)
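
Note: the first thing fcomp_lambda does is flatten attrs.dict, a map of std::string keys and values, into parallel arrays of C strings that can cross the C boundary into the loaded library. A self-contained sketch of that flattening pattern; the print_attrs callee is hypothetical and only stands in for the library's entry point:

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// hypothetical C-style callee standing in for the library's entry point
extern "C" void print_attrs(const char* const* keys, const char* const* vals, int num) {
  for (int i = 0; i < num; i++)
    printf("%s = %s\n", keys[i], vals[i]);
}

int main() {
  std::map<std::string, std::string> dict{{"num_hidden", "64"}, {"act_type", "relu"}};
  // flatten the map into parallel vectors of C strings, as fcomp_lambda does;
  // the c_str() pointers stay valid while dict is alive and unmodified
  std::vector<const char*> attr_keys, attr_vals;
  for (const auto& kv : dict) {
    attr_keys.push_back(kv.first.c_str());
    attr_vals.push_back(kv.second.c_str());
  }
  print_attrs(attr_keys.data(), attr_vals.data(), static_cast<int>(attr_keys.size()));
  return 0;
}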
@@ -406,6 +409,7 @@ int MXLoadLib(const char *path) {
   mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
 
   // create lambda that captures stream & resource objects
+  // the memory pointer returned will eventually return to user
   auto cpu_alloc = [&](int size) {
     mshadow::Tensor<mxnet::cpu, 1, char> data =
       resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size), cpu_stream);
@@ -418,6 +422,7 @@ int MXLoadLib(const char *path) {
   auto cpu_malloc = [](void* _cpu_alloc, int size) {
     // cast the void* argument to the type for the cpu_alloc lambda function
     alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
+    // call cpu_alloc to actually allocate memory and get the pointer
     void* ptr = (*cpualloc)(size);
     return ptr;
   };
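
Note: cpu_malloc exists because a capturing C++ lambda such as cpu_alloc cannot decay to a plain C function pointer. The workaround used here is the classic trampoline pattern: pass a captureless thunk plus an opaque context pointer across the C ABI, and have the thunk cast the context back to the typed callable. A self-contained sketch of the pattern; the names mirror the diff, but the harness and the std::function-based alloc_type are assumptions:

#include <cstdlib>
#include <functional>
#include <iostream>

// assumed: alloc_type names the type of the capturing cpu_alloc callable;
// std::function is used here so the sketch stays self-contained
typedef std::function<void*(int)> alloc_type;

// the C-compatible callback signature the external library expects
typedef void* (*xpu_malloc_t)(void* ctx, int size);

// stand-in for a library-side function that asks the framework for memory
void library_side(xpu_malloc_t xpu_malloc, void* ctx) {
  void* buf = xpu_malloc(ctx, 64);  // request 64 bytes via the trampoline
  std::cout << "library got buffer at " << buf << std::endl;
  free(buf);  // sketch only; MXNet's temp space is not freed this way
}

int main() {
  // capturing side: the real cpu_alloc wraps resource.get_space_typed
  alloc_type cpu_alloc = [](int size) { return malloc(size); };
  // captureless trampoline: recovers the typed callable from the void* ctx
  auto cpu_malloc = [](void* _cpu_alloc, int size) -> void* {
    alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
    return (*cpualloc)(size);
  };
  library_side(cpu_malloc, &cpu_alloc);
  return 0;
}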
@@ -547,37 +552,39 @@ int MXLoadLib(const char *path) {
   };
 
   // stateful forward and backward
-  auto fstateful_lambda = [=](bool forward,
+  auto fstateful_lambda = [=](bool is_forward,
                               const OpStatePtr& state_ptr,
                               const OpContext& ctx,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
-    // create a vector of tensors for inputs
-    std::vector<MXTensor> c_inputs(inputs.size());
+    std::vector<void*> in_data, out_data;
+    std::vector<const int64_t *> in_shapes, out_shapes;
+    std::vector<int> in_dims, out_dims;
+    std::vector<int> in_types, out_types;
 
+    // convert input tensors to constituent parts
     for (size_t i = 0; i < inputs.size(); i++) {
-      c_inputs[i].data_ptr = inputs[i].data().dptr_;
-      c_inputs[i].dtype = (MXDType)inputs[i].dtype();
-      for (int_least16_t j = 0; j < inputs[i].shape().ndim(); j++) {
-        c_inputs[i].shape.push_back(inputs[i].shape().data()[j]);
-      }
+      in_data.push_back(inputs[i].data().dptr_);
+      in_shapes.push_back(inputs[i].shape().data());
+      in_dims.push_back(inputs[i].shape().ndim());
+      in_types.push_back(inputs[i].dtype());
     }
 
-    // create a vector of tensors for outputs
-    std::vector<MXTensor> c_outputs(outputs.size());
+    // convert output tensors to constituent parts
     for (size_t i = 0; i < outputs.size(); i++) {
-      c_outputs[i].data_ptr = outputs[i].data().dptr_;
-      c_outputs[i].dtype = (MXDType)outputs[i].dtype();
-      for (int j = 0; j < outputs[i].shape().ndim(); j++) {
-        c_outputs[i].shape.push_back(outputs[i].shape().data()[j]);
-      }
+      out_data.push_back(outputs[i].data().dptr_);
+      out_shapes.push_back(outputs[i].shape().data());
+      out_dims.push_back(outputs[i].shape().ndim());
+      out_types.push_back(outputs[i].dtype());
     }
 
     // get memory resource
     const Resource &resource = ctx.requested[0];
     mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
 
     // create lambda that captures stream & resource objects
+    // the memory pointer returned will eventually return to user
     auto cpu_alloc = [&](int size) {
       mshadow::Tensor<mxnet::cpu, 1, char> data =
         resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size), cpu_stream);

A reviewer left an inline note on the added comment line: "nit: consider rephrasing."
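
Note: instead of building MXTensor objects on the MXNet side and calling the op's Forward/Backward directly, this commit ships the raw constituent parts (data pointers, shapes, dims, dtypes) plus the state_op pointer through callFStatefulComp, the entry point resolved from the library in the first hunk. The real counterpart is defined in the PR's shared header; the sketch below is a hypothetical reconstruction with illustrative types and symbol name, showing how such an entry point could rebuild tensors and dispatch to the user's stateful op:

#include <cstddef>
#include <cstdint>
#include <vector>

// illustrative stand-ins; the PR's real definitions live in its shared header
enum MXDType { kFloat32 = 0, kFloat64 = 1, kUnknown = -1 };

struct MXTensor {
  void* data_ptr;
  std::vector<int64_t> shape;
  MXDType dtype;
};

// assumed shape of the PR's CustomStatefulOp interface
class CustomStatefulOp {
 public:
  virtual int Forward(std::vector<MXTensor>& inputs,
                      std::vector<MXTensor>& outputs) = 0;
  virtual int Backward(std::vector<MXTensor>& inputs,
                       std::vector<MXTensor>& outputs) = 0;
};

typedef void* (*xpu_malloc_t)(void*, int);

extern "C" int _opCallFStatefulCompute(int is_forward, void* state_op,
                                       const int64_t** inshapes, int* indims,
                                       void** indata, int* intypes, size_t num_in,
                                       const int64_t** outshapes, int* outdims,
                                       void** outdata, int* outtypes, size_t num_out,
                                       xpu_malloc_t cpu_malloc, void* cpu_alloc) {
  // rebuild input/output tensors from their constituent parts
  std::vector<MXTensor> inputs(num_in), outputs(num_out);
  for (size_t i = 0; i < num_in; i++) {
    inputs[i].data_ptr = indata[i];
    inputs[i].dtype = static_cast<MXDType>(intypes[i]);
    inputs[i].shape.assign(inshapes[i], inshapes[i] + indims[i]);
  }
  for (size_t i = 0; i < num_out; i++) {
    outputs[i].data_ptr = outdata[i];
    outputs[i].dtype = static_cast<MXDType>(outtypes[i]);
    outputs[i].shape.assign(outshapes[i], outshapes[i] + outdims[i]);
  }
  // the allocator pair would normally be wrapped in an OpResource (omitted here)
  (void)cpu_malloc; (void)cpu_alloc;
  // state_op is the instance MXNet got back from CreateOpState
  CustomStatefulOp* op = static_cast<CustomStatefulOp*>(state_op);
  return is_forward ? op->Forward(inputs, outputs) : op->Backward(inputs, outputs);
}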
@@ -590,26 +597,23 @@ int MXLoadLib(const char *path) {
   auto cpu_malloc = [](void* _cpu_alloc, int size) {
     // cast the void* argument to the type for the cpu_alloc lambda function
     alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
+    // call cpu_alloc to actually allocate memory and get the pointer
     void* ptr = (*cpualloc)(size);
     return ptr;
   };
 
-  OpResource op_res(cpu_malloc, &cpu_alloc);
-
   // retrieve op state object created from CreateOpState
   CustomStatefulOpWrapper& op = state_ptr.get_state<CustomStatefulOpWrapper>();
   CustomStatefulOp* state_op_inst = op.get_instance();
   CHECK(state_op_inst != nullptr)
     << "Error MXNet cannot load custom stateful operator'" << name_str << "'";
 
-  if (forward) {
-    CHECK(state_op_inst->Forward(c_inputs, c_outputs, op_res))
-      << "Error calling ForwardStateful for custom operator '" << name_str << "'";
-  } else {
-    CHECK(state_op_inst->Backward(c_inputs, c_outputs, op_res))
-      << "Error calling BackwardStateful for custom operator '" << name_str << "'";
-  }
+  // return type void
+  // call fcompute function
+  CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(),
+                          in_data.data(), in_types.data(), in_data.size(),
+                          out_shapes.data(), out_dims.data(), out_data.data(),
+                          out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc))
+    << "Error calling FStatefulCompute for custom operator '" << name_str << "'";
 };
 
 auto fstateful_forward = [=](const OpStatePtr& state_ptr,