Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions swig/lightgbmlib.i
Original file line number Diff line number Diff line change
Expand Up @@ -74,41 +74,48 @@
return dst;
}

int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
jdoubleArray data,
BoosterHandle handle,
int data_type,
int ncol,
int is_row_major,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result) {
jdoubleArray LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
jdoubleArray data,
BoosterHandle handle,
int data_type,
int ncol,
int is_row_major,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
// Note: we allocate the output array in the native side but return new java array
double* out_result = new double[*out_len];

int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type,
num_iteration, parameter, out_len, out_result);

jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);

if (ret != 0) {
return nullptr;
Copy link
Contributor Author

@imatiach-msft imatiach-msft Jun 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like I need to delete before returning here in case of error

delete[] out_result;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you use

vector<double> data0;
data0.resize(*out_len);

this will get you RAII behavior for errors (exceptions,...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea

}

return ret;
jdoubleArray new_array = jenv->NewDoubleArray(*out_len);
jenv->SetDoubleArrayRegion(new_array, 0, *out_len, out_result);
delete[] out_result;
return new_array;
}

int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
jintArray indices,
jdoubleArray values,
int numNonZeros,
BoosterHandle handle,
int indptr_type,
int data_type,
int64_t nelem,
int64_t num_col,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result) {
jdoubleArray LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
jintArray indices,
jdoubleArray values,
int numNonZeros,
BoosterHandle handle,
int indptr_type,
int data_type,
int64_t nelem,
int64_t num_col,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len) {
// Alternatives
// - GetIntArrayElements: performs copy
// - GetDirectBufferAddress: fails on wrapped array
Expand All @@ -118,7 +125,8 @@
jboolean isCopy;
int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, &isCopy);
double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, &isCopy);

// Note: we allocate the output array in the native side but return new java array
double* out_result = new double[*out_len];
int32_t ind[2] = { 0, numNonZeros };

int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
Expand All @@ -127,7 +135,13 @@
jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);

return ret;
if (ret != 0) {
return nullptr;
}
jdoubleArray new_array = jenv->NewDoubleArray(*out_len);
jenv->SetDoubleArrayRegion(new_array, 0, *out_len, out_result);
delete[] out_result;
return new_array;
}

#include <functional>
Expand Down Expand Up @@ -243,7 +257,7 @@
static void delete_##NAME(TYPE *self) { %}
%{ if (self) delete self; %}
%{}
%}
%}

TYPE *new_##NAME();
void delete_##NAME(TYPE *self);
Expand Down Expand Up @@ -289,7 +303,7 @@ TYPE *NAME##_handle();
static void NAME##_setitem(TYPE *ary, int64_t index, TYPE value) {
ary[index] = value;
}
%}
%}

TYPE *new_##NAME(int64_t nelements);
void delete_##NAME(TYPE *ary);
Expand Down