-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Scala] add CustomOp support #4118
Changes from 1 commit
1d6a263
1ea58d3
c92382b
8a4346c
a814af4
c69523c
9fb9394
8c4a3a6
bb202a7
a93b3ef
43002ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1719,6 +1719,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRtcFree | |
std::unordered_map<std::string, jobject> globalOpPropMap; | ||
std::unordered_map<std::string, int> globalOpPropCountMap; | ||
std::unordered_map<std::string, jobject> globalOpMap; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder whether we have a better way to deal with the registry here, can we make the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what I actually tried to implement is that because when those callback function was called, it was called in C++ side, then inside the callback function it needs to find the specific java object then |
||
std::mutex mutex_opprop; | ||
std::mutex mutex_op; | ||
|
||
template<typename T> | ||
class Ref { | ||
|
@@ -1743,8 +1745,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
|
||
const char *regName = env->GetStringUTFChars(jregName, 0); | ||
std::string key(regName); | ||
|
||
std::unique_lock<std::mutex> lock(mutex_opprop); | ||
globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) }); | ||
globalOpPropCountMap.insert({ key, 0 }); | ||
lock.unlock(); | ||
|
||
auto creatorLambda = [](const char *opType, const int numKwargs, | ||
const char **keys, const char **values, CustomOpPropInfo *ret) { | ||
|
@@ -2034,7 +2039,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
} | ||
delete[] ts; | ||
|
||
std::unique_lock<std::mutex> lock(mutex_op); | ||
globalOpMap.insert({ key, env->NewGlobalRef(jOp) }); | ||
lock.unlock(); | ||
|
||
_jvm->DetachCurrentThread(); | ||
|
||
|
@@ -2129,6 +2136,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
auto delEntry = [](void *state) { | ||
std::string key((char *)state); | ||
bool success = true; | ||
std::unique_lock<std::mutex> lock(mutex_op); | ||
if (globalOpMap.find(key) == globalOpMap.end()) { | ||
LOG(FATAL) << "op: " << key << " not found"; | ||
success = false; | ||
|
@@ -2142,6 +2150,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
else ++it; | ||
} | ||
} | ||
lock.unlock(); | ||
return success; | ||
}; | ||
|
||
|
@@ -2159,6 +2168,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
// del callback | ||
auto opPropDel = [](void *state) { | ||
std::string key((char *)state); | ||
std::unique_lock<std::mutex> lock(mutex_opprop); | ||
int count_prop = globalOpPropCountMap.at(key); | ||
if (count_prop < 2) { | ||
globalOpPropCountMap[key] = ++count_prop; | ||
|
@@ -2182,6 +2192,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister | |
else ++it; | ||
} | ||
} | ||
lock.unlock(); | ||
return success; | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to give comments to indicate what those global variables for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it