Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Scala] add CustomOp support #4118

Merged
merged 11 commits into from
Dec 15, 2016
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix thread safty problem
  • Loading branch information
Ldpe2G committed Dec 15, 2016
commit 8c4a3a6ac6c2d7b093441ea8f4c31eac21cb0301
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRtcFree
std::unordered_map<std::string, jobject> globalOpPropMap;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to give comments to indicate what those global variables for.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it

std::unordered_map<std::string, int> globalOpPropCountMap;
std::unordered_map<std::string, jobject> globalOpMap;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we have a better way to deal with the registry here, can we make the jobject, i.e., make the user defined class a closure, like what python code does ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what I actually tried to implement is that because when those callback function was called, it was called in C++ side, then inside the callback function it needs to find the specific java object then
it can call the java object function, so I store the object in a map during initialization, then retrieve them with their name.

std::mutex mutex_opprop;
std::mutex mutex_op;

template<typename T>
class Ref {
Expand All @@ -1743,8 +1745,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister

const char *regName = env->GetStringUTFChars(jregName, 0);
std::string key(regName);

std::unique_lock<std::mutex> lock(mutex_opprop);
globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
globalOpPropCountMap.insert({ key, 0 });
lock.unlock();

auto creatorLambda = [](const char *opType, const int numKwargs,
const char **keys, const char **values, CustomOpPropInfo *ret) {
Expand Down Expand Up @@ -2034,7 +2039,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
}
delete[] ts;

std::unique_lock<std::mutex> lock(mutex_op);
globalOpMap.insert({ key, env->NewGlobalRef(jOp) });
lock.unlock();

_jvm->DetachCurrentThread();

Expand Down Expand Up @@ -2129,6 +2136,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
auto delEntry = [](void *state) {
std::string key((char *)state);
bool success = true;
std::unique_lock<std::mutex> lock(mutex_op);
if (globalOpMap.find(key) == globalOpMap.end()) {
LOG(FATAL) << "op: " << key << " not found";
success = false;
Expand All @@ -2142,6 +2150,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
else ++it;
}
}
lock.unlock();
return success;
};

Expand All @@ -2159,6 +2168,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
// del callback
auto opPropDel = [](void *state) {
std::string key((char *)state);
std::unique_lock<std::mutex> lock(mutex_opprop);
int count_prop = globalOpPropCountMap.at(key);
if (count_prop < 2) {
globalOpPropCountMap[key] = ++count_prop;
Expand All @@ -2182,6 +2192,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
else ++it;
}
}
lock.unlock();
return success;
};

Expand Down