Skip to content

Commit 1878b55

Browse files
authored
Model by Path(URI) and with Chunking (#38)
* Model by Path(URI) and with Chunking * Address DeepSource and Deprecate future final(s) * default 512mb * add doc
1 parent 6ae0ef2 commit 1878b55

File tree

5 files changed

+171
-7
lines changed

5 files changed

+171
-7
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,10 @@ and
5858

5959
client.runModel("model", new String[] {"a", "b"}, new String[] {"c"});
6060
```
61+
62+
### Note
63+
64+
Since version `0.10.0`, the chunk size of model (blob) is set to 512mb (536870912 bytes) based on default Redis
65+
configuration. This behavior can be changed by `redisai.blob.chunkSize` system property at the beginning of the
66+
application. For example, chunk size can be limited to 8mb by setting `-Dredisai.blob.chunkSize=8388608` or
67+
`System.setProperty(Model.BLOB_CHUNK_SIZE_PROPERTY, "8388608");`. A limit of 0 (zero) would disable chunking.

src/main/java/com/redislabs/redisai/Model.java

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,47 @@
11
package com.redislabs.redisai;
22

33
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
4+
import java.io.IOException;
5+
import java.net.URI;
6+
import java.nio.file.Files;
7+
import java.nio.file.Paths;
48
import java.util.ArrayList;
9+
import java.util.Arrays;
510
import java.util.List;
611
import redis.clients.jedis.Protocol;
712
import redis.clients.jedis.util.SafeEncoder;
813

914
/** Direct mapping to RedisAI Model */
1015
public class Model {
1116

12-
private Backend backend;
13-
private Device device;
17+
public static final String BLOB_CHUNK_SIZE_PROPERTY = "redisai.blob.chunkSize";
18+
19+
private static final int BLOB_CHUNK_SIZE =
20+
Integer.parseInt(System.getProperty(BLOB_CHUNK_SIZE_PROPERTY, "536870912"));
21+
22+
private Backend backend; // TODO: final
23+
private Device device; // TODO: final
1424
private String[] inputs;
1525
private String[] outputs;
16-
private byte[] blob;
26+
private byte[] blob; // TODO: final
1727
private String tag;
1828
private long batchSize;
1929
private long minBatchSize;
2030
private long minBatchTimeout;
2131

32+
/**
33+
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
34+
* @param device - the device that will execute the model. can be of CPU or GPU
35+
* @param modelUri - filepath of the Protobuf-serialized model
36+
* @throws java.io.IOException
37+
* @see #Model(com.redislabs.redisai.Backend, com.redislabs.redisai.Device, byte[])
38+
* @see Files#readAllBytes(java.nio.file.Path)
39+
* @see Paths#get(java.net.URI)
40+
*/
41+
public Model(Backend backend, Device device, URI modelUri) throws IOException {
42+
this(backend, device, Files.readAllBytes(Paths.get(modelUri)));
43+
}
44+
2245
/**
2346
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
2447
* @param device - the device that will execute the model. can be of CPU or GPU
@@ -166,6 +189,11 @@ public byte[] getBlob() {
166189
return blob;
167190
}
168191

192+
/**
193+
* @param blob
194+
* @deprecated This variable will be final. Use any constructor.
195+
*/
196+
@Deprecated
169197
public void setBlob(byte[] blob) {
170198
this.blob = blob;
171199
}
@@ -192,6 +220,11 @@ public Device getDevice() {
192220
return device;
193221
}
194222

223+
/**
224+
* @param device
225+
* @deprecated This variable will be final. Use any constructor.
226+
*/
227+
@Deprecated
195228
public void setDevice(Device device) {
196229
this.device = device;
197230
}
@@ -200,6 +233,11 @@ public Backend getBackend() {
200233
return backend;
201234
}
202235

236+
/**
237+
* @param backend
238+
* @deprecated This variable will be final. Use any constructor.
239+
*/
240+
@Deprecated
203241
public void setBackend(Backend backend) {
204242
this.backend = backend;
205243
}
@@ -314,11 +352,26 @@ protected List<byte[]> getModelStoreCommandArgs(String key) {
314352
}
315353

316354
args.add(Keyword.BLOB.getRaw());
317-
args.add(blob);
355+
collectChunks(args, blob);
318356

319357
return args;
320358
}
321359

360+
private static void collectChunks(List<byte[]> collector, byte[] array) {
361+
final int chunkSize = BLOB_CHUNK_SIZE;
362+
if (chunkSize <= 0 || array.length <= chunkSize) {
363+
collector.add(array);
364+
return;
365+
}
366+
367+
int from = 0;
368+
while (from < array.length) {
369+
int copySize = Math.min(array.length - from, chunkSize);
370+
collector.add(Arrays.copyOfRange(array, from, from + copySize));
371+
from += copySize;
372+
}
373+
}
374+
322375
protected static List<byte[]> modelRunFlatArgs(
323376
String key, String[] inputs, String[] outputs, boolean includeCommandName) {
324377
List<byte[]> args = new ArrayList<>();
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.redislabs.redisai;
2+
3+
import java.io.IOException;
4+
import java.net.URISyntaxException;
5+
import org.junit.AfterClass;
6+
import org.junit.Assert;
7+
import org.junit.BeforeClass;
8+
import org.junit.Ignore;
9+
import org.junit.Test;
10+
11+
@Ignore
12+
public class ChunkTest {
13+
14+
private static final int SMALL_CHUNK_SIZE = 8 * 1024; // 8KB
15+
16+
@BeforeClass
17+
public static void prepare() {
18+
System.setProperty(Model.BLOB_CHUNK_SIZE_PROPERTY, Integer.toString(SMALL_CHUNK_SIZE));
19+
}
20+
21+
@AfterClass
22+
public static void cleanUp() {
23+
System.clearProperty(Model.BLOB_CHUNK_SIZE_PROPERTY);
24+
}
25+
26+
/**
27+
* @throws java.net.URISyntaxException
28+
* @throws java.io.IOException
29+
* @see ModelTest#argumentsWithoutChunking()
30+
*/
31+
@Test
32+
public void argumentsWithChunking() throws URISyntaxException, IOException {
33+
Model model =
34+
new Model(
35+
Backend.ONNX,
36+
Device.GPU,
37+
getClass().getClassLoader().getResource("test_data/mnist.onnx").toURI());
38+
39+
Assert.assertEquals(8, model.getModelStoreCommandArgs("key").size());
40+
}
41+
42+
@Test
43+
public void commandWithChunking() throws IOException, URISyntaxException {
44+
Model model =
45+
new Model(
46+
Backend.ONNX,
47+
Device.CPU,
48+
getClass().getClassLoader().getResource("test_data/mnist.onnx").toURI());
49+
50+
try (RedisAI ai = new RedisAI()) {
51+
Assert.assertTrue(ai.storeModel("model-chunk", model));
52+
Assert.assertNotNull(ai.getModel("model-chunk"));
53+
}
54+
}
55+
}

src/test/java/com/redislabs/redisai/ModelTest.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.redislabs.redisai;
22

33
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
4+
import java.io.IOException;
5+
import java.net.URISyntaxException;
46
import java.util.List;
57
import org.junit.Assert;
68
import org.junit.Before;
@@ -37,11 +39,11 @@ public void getSetBlob() {
3739
byte[] expected = new byte[0];
3840
Model model = new Model(Backend.ONNX, Device.GPU, new String[0], new String[0], expected);
3941
byte[] blob = model.getBlob();
40-
Assert.assertEquals(blob, expected);
42+
Assert.assertSame(blob, expected);
4143
byte[] expected2 = new byte[] {0x10};
4244
model.setBlob(expected2);
4345
blob = model.getBlob();
44-
Assert.assertEquals(blob, expected2);
46+
Assert.assertSame(blob, expected2);
4547
}
4648

4749
@Test
@@ -104,6 +106,22 @@ public void getSetMinBatchSize() {
104106
Assert.assertEquals(10, minbatchsize);
105107
}
106108

109+
/**
110+
* @throws java.net.URISyntaxException
111+
* @throws java.io.IOException
112+
* @see ChunkTest#argumentsWithChunking()
113+
*/
114+
@Test
115+
public void argumentsWithoutChunking() throws URISyntaxException, IOException {
116+
Model model =
117+
new Model(
118+
Backend.ONNX,
119+
Device.GPU,
120+
getClass().getClassLoader().getResource("test_data/mnist.onnx").toURI());
121+
122+
Assert.assertEquals(5, model.getModelStoreCommandArgs("key").size());
123+
}
124+
107125
@Test
108126
public void createModelFromRespReply() {
109127
// negative testing

src/test/java/com/redislabs/redisai/RedisAITest.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.redislabs.redisai;
22

33
import java.io.IOException;
4+
import java.net.URISyntaxException;
5+
import java.net.URL;
46
import java.nio.file.Files;
57
import java.nio.file.Paths;
68
import java.util.Map;
@@ -256,8 +258,12 @@ public void storeModel() throws IOException {
256258
String[] outputs = new String[] {"mul"};
257259
byte[] blob = IOUtils.resourceToByteArray("test_data/graph.pb", getClass().getClassLoader());
258260

259-
Model createdModel = new Model(Backend.TF, Device.CPU, inputs, outputs, blob);
261+
Model createdModel = new Model(Backend.TF, Device.CPU, blob);
262+
createdModel.setInputs(inputs);
263+
createdModel.setOutputs(outputs);
264+
260265
Assert.assertTrue(client.storeModel("model-1", createdModel));
266+
261267
Model readModel1 = client.getModel("model-1");
262268
Assert.assertEquals(Backend.TF, readModel1.getBackend());
263269
Assert.assertEquals(Device.CPU, readModel1.getDevice());
@@ -272,7 +278,9 @@ public void storeModel() throws IOException {
272278
createdModel.setMinBatchSize(5);
273279
createdModel.setMinBatchTimeout(15);
274280
createdModel.setTag("test batching params");
281+
275282
Assert.assertTrue(client.storeModel("model-2", createdModel));
283+
276284
Model readModel2 = client.getModel("model-2");
277285
Assert.assertEquals(Backend.TF, readModel2.getBackend());
278286
Assert.assertEquals(Device.CPU, readModel2.getDevice());
@@ -284,6 +292,29 @@ public void storeModel() throws IOException {
284292
Assert.assertEquals("test batching params", readModel2.getTag());
285293
}
286294

295+
@Test
296+
public void storeModelByPath() throws URISyntaxException, IOException {
297+
String[] inputs = new String[] {"a", "b"};
298+
String[] outputs = new String[] {"mul"};
299+
URL modelUrl = getClass().getClassLoader().getResource("test_data/graph.pb");
300+
301+
Model createdModel = new Model(Backend.TF, Device.CPU, modelUrl.toURI());
302+
createdModel.setInputs(inputs);
303+
createdModel.setOutputs(outputs);
304+
305+
Assert.assertTrue(client.storeModel("model-uri", createdModel));
306+
307+
Model readModel = client.getModel("model-uri");
308+
Assert.assertEquals(Backend.TF, readModel.getBackend());
309+
Assert.assertEquals(Device.CPU, readModel.getDevice());
310+
Assert.assertArrayEquals(inputs, readModel.getInputs());
311+
Assert.assertArrayEquals(outputs, readModel.getOutputs());
312+
Assert.assertEquals(0L, readModel.getBatchSize());
313+
Assert.assertEquals(0L, readModel.getMinBatchSize());
314+
Assert.assertEquals(0L, readModel.getMinBatchTimeout());
315+
Assert.assertEquals("", readModel.getTag());
316+
}
317+
287318
@Test
288319
public void storeModelFail() throws IOException {
289320
try {

0 commit comments

Comments
 (0)