Skip to content

Commit 458933c

Browse files
authored
[api] Fixed NDList decode numpy file bug (#2804)
1 parent d432a65 commit 458933c

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

api/src/main/java/ai/djl/ndarray/NDList.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ public static NDList decode(NDManager manager, byte[] byteArray) {
100100
try {
101101
if (byteArray[0] == 'P' && byteArray[1] == 'K') {
102102
return decodeNumpy(manager, new ByteArrayInputStream(byteArray));
103-
} else if (byteArray[0] == (byte) 0x39
103+
} else if (byteArray[0] == (byte) 0x93
104104
&& byteArray[1] == 'N'
105105
&& byteArray[2] == 'U'
106106
&& byteArray[3] == 'M') {
107107
return new NDList(
108-
NDSerializer.decode(manager, new ByteArrayInputStream(byteArray)));
108+
NDSerializer.decodeNumpy(manager, new ByteArrayInputStream(byteArray)));
109109
} else if (byteArray[8] == '{') {
110110
return decodeSafetensors(manager, new ByteArrayInputStream(byteArray));
111111
}
@@ -144,11 +144,11 @@ public static NDList decode(NDManager manager, InputStream is) {
144144
if (magic[0] == 'P' && magic[1] == 'K') {
145145
// assume this is npz file
146146
return decodeNumpy(manager, pis);
147-
} else if (magic[0] == (byte) 0x39
147+
} else if (magic[0] == (byte) 0x93
148148
&& magic[1] == 'N'
149149
&& magic[2] == 'U'
150150
&& magic[3] == 'M') {
151-
return new NDList(NDSerializer.decode(manager, pis));
151+
return new NDList(NDSerializer.decodeNumpy(manager, pis));
152152
} else if (magic[8] == '{') {
153153
return decodeSafetensors(manager, pis);
154154
}

api/src/test/java/ai/djl/ndarray/NDSerializerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ private static byte[] encode(NDArray array) throws IOException {
107107

108108
private static NDArray decode(NDManager manager, byte[] data) throws IOException {
109109
try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) {
110-
return NDSerializer.decodeNumpy(manager, bis);
110+
return NDList.decode(manager, bis).get(0);
111111
}
112112
}
113113

0 commit comments

Comments
 (0)