Skip to content

Commit 03bf47b

Browse files
author
feklin.v.v
committed
Added percentile implementation for Pytorch engine
1 parent a73f9e1 commit 03bf47b

File tree

5 files changed

+57
-2
lines changed

5 files changed

+57
-2
lines changed

engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,13 +1547,17 @@ public PtNDArray argMin(int axis) {
15471547
/** {@inheritDoc} */
15481548
@Override
15491549
public PtNDArray percentile(Number percentile) {
1550-
throw new UnsupportedOperationException("Not implemented");
1550+
return percentile(percentile, new int[] {-1});
15511551
}
15521552

15531553
/** {@inheritDoc} */
15541554
@Override
15551555
public PtNDArray percentile(Number percentile, int[] axes) {
1556-
throw new UnsupportedOperationException("Not implemented");
1556+
if (axes.length != 1) {
1557+
throw new UnsupportedOperationException(
1558+
"Not supporting zero or multi-dimension percentile");
1559+
}
1560+
return JniUtils.percentile(this, percentile, axes[0], false);
15571561
}
15581562

15591563
/** {@inheritDoc} */

engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,14 @@ public static NDList median(PtNDArray ndArray, long dim, boolean keepDim) {
888888
new PtNDArray(ndArray.getManager(), handles[1]));
889889
}
890890

891+
public static PtNDArray percentile(
892+
PtNDArray ndArray, Number percentile, long dim, boolean keepDim) {
893+
float quantile = percentile.floatValue() / 100;
894+
return new PtNDArray(
895+
ndArray.getManager(),
896+
PyTorchLibrary.LIB.torchQuantile(ndArray.getHandle(), quantile, dim, keepDim));
897+
}
898+
891899
public static PtNDArray mean(PtNDArray ndArray) {
892900
return new PtNDArray(
893901
ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle()));

engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@ native void torchIndexPut(
248248

249249
native long[] torchMedian(long self, long dim, boolean keepDim);
250250

251+
native long torchQuantile(long self, float quantile, long dim, boolean keepDim);
252+
251253
native long torchMin(long handle);
252254

253255
native long torchMin(long handle, long dim, boolean keepDim);

engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMedian(
227227
API_END_RETURN()
228228
}
229229

230+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchQuantile(
231+
JNIEnv* env, jobject jthis, jlong jself, jfloat q, jlong jdim, jboolean keep_dim) {
232+
API_BEGIN()
233+
const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
234+
const auto* result_ptr = new torch::Tensor(torch::quantile(*self_ptr, q, jdim, keep_dim));
235+
return reinterpret_cast<uintptr_t>(result_ptr);
236+
API_END_RETURN()
237+
}
238+
230239
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAbs(JNIEnv* env, jobject jthis, jlong jhandle) {
231240
API_BEGIN()
232241
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);

integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementComparisonOpTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,38 @@ public void testMedian() {
611611
}
612612
}
613613

614+
@Test
615+
public void testPercentile() {
616+
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
617+
NDArray array1 = manager.create(new float[] {1, 3, 2, 5, 4});
618+
Assert.assertEquals(array1.percentile(10), manager.create(1.4f));
619+
Assert.assertEquals(array1.percentile(90), manager.create(4.6f));
620+
621+
NDArray array2 =
622+
manager.create(
623+
new float[][] {
624+
{11, 12, 13, 14, 15},
625+
{21, 22, 23, 24, 25},
626+
{31, 32, 33, 34, 35},
627+
{41, 42, 43, 44, 45},
628+
{51, 52, 53, 54, 55}
629+
});
630+
Assert.assertEquals(
631+
array2.percentile(10, new int[] {1}),
632+
manager.create(new float[] {11.4f, 21.4f, 31.4f, 41.4f, 51.4f}));
633+
Assert.assertEquals(
634+
array2.percentile(90, new int[] {1}),
635+
manager.create(new float[] {14.6f, 24.6f, 34.6f, 44.6f, 54.6f}));
636+
637+
Assert.assertEquals(
638+
array2.percentile(10, new int[] {0}),
639+
manager.create(new float[] {15, 16, 17, 18, 19}));
640+
Assert.assertEquals(
641+
array2.percentile(90, new int[] {0}),
642+
manager.create(new float[] {47, 48, 49, 50, 51}));
643+
}
644+
}
645+
614646
@Test
615647
public void testWhere() {
616648
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {

0 commit comments

Comments
 (0)