This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Scala] add CustomOp support #4118

Merged
merged 11 commits into from
Dec 15, 2016

Conversation

Contributor

@Ldpe2G Ldpe2G commented Dec 6, 2016

follows: https://github.com/dmlc/mxnet/blob/master/python/mxnet/operator.py
@Javelinjs Could you help review the code?
A review from anyone else would also be appreciated.
I am not sure whether I am implementing it right; this is the most reasonable approach I could come up with.

Member

@yzhliu yzhliu left a comment

So far I have only reviewed the code conventions. This is great work. I'll check the correctness and design later tomorrow. My first impression is that the C code is a little complex and may risk leaking memory.

* array of aux shapes calculated from in_shape,
* in the same order as declared in listAuxiliaryStates().
*/
def inferShape(inShape: Array[Array[Int]]):
Member

It seems you should use Array[Shape] instead.

} catch {
case ex: Throwable => {
success = false
new Throwable().printStackTrace()
Member

@yzhliu yzhliu Dec 7, 2016

try ... catch ..., set success, return 0/1
hmm... this is somewhat weird

Contributor Author

or change to return boolean?

Member

ex.printStackTrace()

// set CustomOpProp.kwargs
std::string opPropKey(opType);
if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
std::cout << "CustomOpProp: " << opPropKey << " not found" << std::endl;
Member

Can we use dmlc::log?

auto delEntry = [](void *state) {
std::string key((char *)state);
bool success = true;
// if (globalOpMap.find(key) == globalOpMap.end()) {
Member

debug codes?

Contributor Author

@Ldpe2G Ldpe2G Dec 8, 2016

Actually, I use global maps to store the CustomOp and CustomOpProp objects, and I wanted to remove them from the map and release them in the del function. But this delEntry function seems to be called during execution and causes some problems, so I commented the code out.

Member

I'm not quite sure what happens here. ping @piiswrong


CustomOpPropCreator creator = static_cast<bool(*)(const char*, const int, const char**, const char**, CustomOpPropInfo*)>(creatorLambda);
int ret = MXCustomOpRegister(regName, creator);
return ret;
Member

simply return MXCustomOpRegister(regName, creator);

if (obj != NULL) {
jobjectArray jauxs = (jobjectArray)obj;
int len = globalEnv->GetArrayLength(jauxs);
(*auxs) = new char *[len+1];
Member

new but never delete?

Contributor Author

The auxs pointer is used to store the result and is passed in by the caller.

Member

Yes, but you should not do it this way. We need to find a way to delete them eventually, e.g., use an object like mxnet::common::ThreadLocalStore to manage it. Or, and I'm not sure about this, is there a way to convert a Java String to char * while letting the JVM hold the reference and garbage-collect it automatically?
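
To illustrate the suggestion above, here is a minimal sketch of a per-thread holder that owns the strings handed back through an output pointer, so no bare `new` is left without a matching `delete`. It uses the standard C++ `thread_local` keyword as a stand-in for a thread-local store; `StringArrayHolder`, `listAuxStates`, and the state names are hypothetical and are not taken from the PR or from MXNet.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not part of MXNet): owns the C strings handed out to a
// C-style caller. One instance lives per thread.
struct StringArrayHolder {
  std::vector<std::string> strings;   // owns the character data
  std::vector<const char*> pointers;  // NULL-terminated view over `strings`

  // Replaces the stored strings and returns a NULL-terminated array that stays
  // valid until the next call to set() on the same thread.
  const char** set(const std::vector<std::string>& values) {
    strings = values;
    pointers.clear();
    for (const std::string& s : strings) pointers.push_back(s.c_str());
    pointers.push_back(nullptr);
    return pointers.data();
  }
};

// Stand-in for a callback that reports auxiliary state names through an
// output pointer. The names are illustrative only.
bool listAuxStates(const char*** auxs) {
  thread_local StringArrayHolder holder;  // released when the thread exits
  *auxs = holder.set({"moving_mean", "moving_var"});
  return true;
}
```

The trade-off in this design is that the returned array is only valid until the same thread calls the function again, so a caller that needs the data for longer has to copy it.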

Contributor Author

OK, I'll take a look.

Contributor Author

I think the memory release should be handled on the caller side? The caller calls this function and passes in a pointer, for example auxs; after the call, auxs holds the results the caller needs. If you release the memory inside the function, how do you know that the caller no longer needs it at the moment you release it?

Contributor Author

@Ldpe2G Ldpe2G Dec 10, 2016

I figured out a way to handle the memory release in 64a65cc. What's your opinion?

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxCustomOpRegister
(JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {

globalEnv = env;
Member

Please refer to KVStoreServerControllerFunc. I don't think what is done here is correct.

Contributor Author

I will take a look.

Contributor Author

The reason I use a global variable here is that a local variable used inside a C++ lambda has to be captured, and a capturing lambda cannot be converted to a function pointer.
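
The constraint described here can be shown in a small, self-contained sketch. It assumes a C-style API that passes an opaque `void*` state pointer to the callback; the `Callback` type, `invoke`, and `runWithState` are hypothetical names. The creator callback in this PR does not appear to receive such a pointer, which is presumably why a global is used there instead.

```cpp
#include <string>

// Hypothetical C-style callback type: a plain function pointer plus an opaque
// state pointer supplied by whoever registers the callback.
typedef bool (*Callback)(const char* name, void* state);

// Stand-in for a C API that invokes the registered callback.
bool invoke(Callback cb, const char* name, void* state) {
  return cb(name, state);
}

bool runWithState(const std::string& expected) {
  std::string local = expected;

  // A capturing lambda has no conversion to a function pointer, so the
  // following line would not compile:
  //   Callback bad = [&local](const char* n, void*) { return local == n; };

  // A captureless lambda converts implicitly. The state travels through the
  // void* argument instead of a capture or a global variable.
  Callback ok = [](const char* name, void* state) -> bool {
    const std::string* s = static_cast<const std::string*>(state);
    return *s == name;
  };
  return invoke(ok, "myop", &local);
}
```

Calling `runWithState("myop")` returns true and any other argument returns false, because the callback compares the name it receives with the string reached through the state pointer.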

Contributor Author

Oh, I think I might understand what you mean.

JNIEnv *globalEnv;
std::unordered_map<std::string, jobject> globalOpPropMap;
std::unordered_map<std::string, jobject> globalOpMap;
Member

I wonder whether there is a better way to deal with the registry here. Can we make the jobject, i.e., the user-defined class, a closure, like the Python code does?

Contributor Author

What I am trying to implement is this: the callback functions are invoked on the C++ side, and inside each callback we need to find the specific Java object before we can call its methods. So I store the objects in a map during initialization and then retrieve them by name.
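
The lookup scheme described in this comment can be sketched without JNI as follows. `FakeOp` stands in for the Java object reference (`jobject` in the PR), and `registry`, `registerOp`, and `forwardEntry` are hypothetical names chosen for illustration. The sketch also returns a failure status when the key is missing rather than aborting, which is one possible choice and not necessarily what the PR does.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for a Java object reference; hypothetical.
struct FakeOp {
  int calls = 0;
  bool forward() { ++calls; return true; }
};

// Global registry keyed by operator name, playing the role of globalOpMap.
std::unordered_map<std::string, FakeOp*> registry;

void registerOp(const std::string& name, FakeOp* op) { registry[name] = op; }

// Captureless callback: the key arrives through the opaque state pointer and
// the object is looked up in the registry before its method is called.
bool forwardEntry(void* state) {
  std::string key(static_cast<const char*>(state));
  auto it = registry.find(key);
  if (it == registry.end()) {
    // Report the problem and return a failure status instead of aborting.
    std::cerr << "CustomOp: " << key << " not found" << std::endl;
    return false;
  }
  return it->second->forward();
}
```

Because the callback only needs the key, it stays convertible to a plain function pointer. The cost is that entries in the global map must be removed explicitly when an operator is destroyed.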

Contributor Author

Ldpe2G commented Dec 12, 2016

@Javelinjs I have tested CustomOp with multiple threads, and the results on the MNIST dataset seem fine.

Member

@yzhliu yzhliu left a comment

Please also fix the code conventions, according to http://107.23.36.180:8080/job/mxnet/759/console

@@ -1711,3 +1715,505 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRtcFree
int ret = MXRtcFree(handle);
return ret;
}

std::unordered_map<std::string, jobject> globalOpPropMap;
Member

Better to add comments indicating what those global variables are for.

Contributor Author

got it

// set CustomOpProp.kwargs
std::string opPropKey(opType);
if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
LOG(FATAL) << "CustomOpProp: " << opPropKey << " not found";
Member

LOG(FATAL) will die immediately. Is it important to return a status from this function? If so, you may change the log level to WARN.

jint *rdepsArr = env->GetIntArrayElements(jrdeps, NULL);

(*numDeps) = env->GetArrayLength(jrdeps);
(* rdeps) = new int[(* numDeps)];
Member

you can remove these parentheses.

tagsArr,
reqsArr,
*(const_cast<bool*>(&isTrain)));
if (result == false) success = false;
Member

success = env->call...

same for the others

(*auxs) = new char *[1];
(*auxs)[0] = NULL;
}
dmlc::ThreadLocalStore<Ref<char>>::Get()->setData(*auxs);
Member

OK, let's remove this. So there is no way to avoid the memory leak for this kind of data transfer from the JVM side?

Contributor Author

yes, I think so.

Contributor Author

Ldpe2G commented Dec 14, 2016

@Javelinjs I think I have made all the changes and this PR now satisfies the requirements. Is there anything else I missed?

@yzhliu yzhliu merged commit 87c6404 into apache:master Dec 15, 2016
Member

yzhliu commented Dec 15, 2016

Please also take care once this feature is merged to nnvm branch.

Contributor Author

Ldpe2G commented Dec 16, 2016

@Javelinjs ok
