Skip to content

Commit 34f11ae

Browse files
authored
Support STORE and EXECUTE commands (#36)
1 parent f0c9dfd commit 34f11ae

File tree

11 files changed

+671
-32
lines changed

11 files changed

+671
-32
lines changed

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@
5757
<version>3.6.0</version>
5858
<scope>compile</scope>
5959
</dependency>
60+
<dependency>
61+
<groupId>commons-io</groupId>
62+
<artifactId>commons-io</artifactId>
63+
<version>2.8.0</version>
64+
<scope>test</scope>
65+
</dependency>
6066
</dependencies>
6167

6268
<distributionManagement>

src/main/java/com/redislabs/redisai/Backend.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public enum Backend implements ProtocolCommand {
1515
raw = SafeEncoder.encode(this.name());
1616
}
1717

18+
@Override
1819
public byte[] getRaw() {
1920
return raw;
2021
}

src/main/java/com/redislabs/redisai/Command.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@ public enum Command implements ProtocolCommand {
88
TENSOR_SET("AI.TENSORSET"),
99
MODEL_GET("AI.MODELGET"),
1010
MODEL_SET("AI.MODELSET"),
11+
MODEL_STORE("AI.MODELSTORE"),
1112
MODEL_DEL("AI.MODELDEL"),
1213
MODEL_RUN("AI.MODELRUN"),
14+
MODEL_EXECUTE("AI.MODELEXECUTE"),
1315
SCRIPT_SET("AI.SCRIPTSET"),
1416
SCRIPT_GET("AI.SCRIPTGET"),
1517
SCRIPT_DEL("AI.SCRIPTDEL"),
1618
SCRIPT_RUN("AI.SCRIPTRUN"),
17-
// TODO: support AI.DAGRUN
1819
DAGRUN("AI.DAGRUN"),
19-
// TODO: support AI.DAGRUN_RO
2020
DAGRUN_RO("AI.DAGRUN_RO"),
21+
DAGEXECUTE("AI.DAGEXECUTE"),
22+
DAGEXECUTE_RO("AI.DAGEXECUTE_RO"),
2123
INFO("AI.INFO"),
2224
CONFIG("AI.CONFIG");
2325

@@ -27,6 +29,7 @@ public enum Command implements ProtocolCommand {
2729
raw = SafeEncoder.encode(alt);
2830
}
2931

32+
@Override
3033
public byte[] getRaw() {
3134
return raw;
3235
}

src/main/java/com/redislabs/redisai/Dag.java

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ public Dag() {}
1414
protected List<?> processDagReply(List<?> reply) {
1515
List<Object> outputList = new ArrayList<>(reply.size());
1616
for (int i = 0; i < reply.size(); i++) {
17-
if (this.tensorgetflag.get(i)) {
18-
outputList.add(Tensor.createTensorFromRespReply((List<?>) reply.get(i)));
17+
Object obj = reply.get(i);
18+
// TODO: Should encode 'OK', 'NA', etc. response
19+
if (obj instanceof Exception) {
20+
Exception ex = (Exception) obj;
21+
outputList.add(new RedisAIException(ex.getMessage(), ex));
22+
} else if (this.tensorgetflag.get(i)) {
23+
outputList.add(Tensor.createTensorFromRespReply((List<?>) obj));
1924
} else {
20-
outputList.add(reply.get(i));
25+
outputList.add(obj);
2126
}
2227
}
2328
return outputList;
@@ -47,6 +52,14 @@ public Dag runModel(String key, String[] inputs, String[] outputs) {
4752
return this;
4853
}
4954

55+
@Override
56+
public Dag executeModel(String key, String[] inputs, String[] outputs) {
57+
List<byte[]> args = Model.modelExecuteCommandArgs(key, inputs, outputs, -1L, true);
58+
this.commands.add(args);
59+
this.tensorgetflag.add(false);
60+
return this;
61+
}
62+
5063
@Override
5164
public Dag runScript(String key, String function, String[] inputs, String[] outputs) {
5265
List<byte[]> args = Script.scriptRunFlatArgs(key, function, inputs, outputs, true);
@@ -77,4 +90,37 @@ List<byte[]> dagRunFlatArgs(String[] loadKeys, String[] persistKeys) {
7790
}
7891
return args;
7992
}
93+
94+
List<byte[]> dagExecuteFlatArgs(
95+
String[] loadTensors, String[] persistTensors, String[] keysArgs) {
96+
List<byte[]> args = new ArrayList<>();
97+
if (loadTensors != null && loadTensors.length > 0) {
98+
args.add(Keyword.LOAD.getRaw());
99+
args.add(SafeEncoder.encode(String.valueOf(loadTensors.length)));
100+
for (String key : loadTensors) {
101+
args.add(SafeEncoder.encode(key));
102+
}
103+
}
104+
if (persistTensors != null && persistTensors.length > 0) {
105+
args.add(Keyword.PERSIST.getRaw());
106+
args.add(SafeEncoder.encode(String.valueOf(persistTensors.length)));
107+
for (String key : persistTensors) {
108+
args.add(SafeEncoder.encode(key));
109+
}
110+
}
111+
112+
if (keysArgs != null && keysArgs.length > 0) {
113+
args.add(Keyword.KEYS.getRaw());
114+
args.add(SafeEncoder.encode(String.valueOf(keysArgs.length)));
115+
for (String key : keysArgs) {
116+
args.add(SafeEncoder.encode(key));
117+
}
118+
}
119+
120+
for (List<byte[]> command : this.commands) {
121+
args.add(Keyword.PIPE.getRaw());
122+
args.addAll(command);
123+
}
124+
return args;
125+
}
80126
}

src/main/java/com/redislabs/redisai/DagRunCommands.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ interface DagRunCommands<T> {
77

88
T runModel(String key, String[] inputs, String[] outputs);
99

10+
T executeModel(String key, String[] inputs, String[] outputs);
11+
1012
T runScript(String key, String function, String[] inputs, String[] outputs);
1113
}

src/main/java/com/redislabs/redisai/Keyword.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ public enum Keyword implements ProtocolCommand {
1414
TAG,
1515
BATCHSIZE,
1616
MINBATCHSIZE,
17+
MINBATCHTIMEOUT,
18+
TIMEOUT,
1719
BACKENDSPATH,
1820
LOADBACKEND,
1921
LOAD,
2022
PERSIST,
23+
KEYS,
2124
PIPE("|>");
2225

2326
private final byte[] raw;

src/main/java/com/redislabs/redisai/Model.java

Lines changed: 124 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
/** Direct mapping to RedisAI Model */
1010
public class Model {
11+
1112
private Backend backend;
1213
private Device device;
1314
private String[] inputs;
@@ -16,6 +17,18 @@ public class Model {
1617
private String tag;
1718
private long batchSize;
1819
private long minBatchSize;
20+
private long minBatchTimeout;
21+
22+
/**
23+
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
24+
* @param device - the device that will execute the model. can be of CPU or GPU
25+
* @param blob - the Protobuf-serialized model
26+
*/
27+
public Model(Backend backend, Device device, byte[] blob) {
28+
this.backend = backend;
29+
this.device = device;
30+
this.blob = blob;
31+
}
1932

2033
/**
2134
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
@@ -63,13 +76,13 @@ public Model(
6376
}
6477

6578
public static Model createModelFromRespReply(List<?> reply) {
66-
Model model = null;
6779
Backend backend = null;
6880
Device device = null;
6981
String tag = null;
7082
byte[] blob = null;
7183
long batchsize = 0;
7284
long minbatchsize = 0;
85+
long minbatchtimeout = 0;
7386
String[] inputs = new String[0];
7487
String[] outputs = new String[0];
7588
for (int i = 0; i < reply.size(); i += 2) {
@@ -101,6 +114,9 @@ public static Model createModelFromRespReply(List<?> reply) {
101114
case "minbatchsize":
102115
minbatchsize = (Long) reply.get(i + 1);
103116
break;
117+
case "minbatchtimeout":
118+
minbatchtimeout = (Long) reply.get(i + 1);
119+
break;
104120
case "inputs":
105121
List<byte[]> inputsEncoded = (List<byte[]>) reply.get(i + 1);
106122
if (!inputsEncoded.isEmpty()) {
@@ -123,23 +139,27 @@ public static Model createModelFromRespReply(List<?> reply) {
123139
break;
124140
}
125141
}
142+
126143
if (backend == null || device == null || blob == null) {
127144
throw new JRedisAIRunTimeException(
128145
"AI.MODELGET reply did not contained all elements to build the model");
129146
}
130-
model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
131-
if (tag != null) {
132-
model.setTag(tag);
133-
}
134-
return model;
147+
return new Model(backend, device, blob)
148+
.setInputs(inputs)
149+
.setOutputs(outputs)
150+
.setBatchSize(batchsize)
151+
.setMinBatchSize(minbatchsize)
152+
.setMinBatchTimeout(minbatchtimeout)
153+
.setTag(tag);
135154
}
136155

137156
public String getTag() {
138157
return tag;
139158
}
140159

141-
public void setTag(String tag) {
160+
public Model setTag(String tag) {
142161
this.tag = tag;
162+
return this;
143163
}
144164

145165
public byte[] getBlob() {
@@ -154,16 +174,18 @@ public String[] getOutputs() {
154174
return outputs;
155175
}
156176

157-
public void setOutputs(String[] outputs) {
177+
public Model setOutputs(String[] outputs) {
158178
this.outputs = outputs;
179+
return this;
159180
}
160181

161182
public String[] getInputs() {
162183
return inputs;
163184
}
164185

165-
public void setInputs(String[] inputs) {
186+
public Model setInputs(String[] inputs) {
166187
this.inputs = inputs;
188+
return this;
167189
}
168190

169191
public Device getDevice() {
@@ -186,16 +208,27 @@ public long getBatchSize() {
186208
return batchSize;
187209
}
188210

189-
public void setBatchSize(long batchsize) {
211+
public Model setBatchSize(long batchsize) {
190212
this.batchSize = batchsize;
213+
return this;
191214
}
192215

193216
public long getMinBatchSize() {
194217
return minBatchSize;
195218
}
196219

197-
public void setMinBatchSize(long minbatchsize) {
220+
public Model setMinBatchSize(long minbatchsize) {
198221
this.minBatchSize = minbatchsize;
222+
return this;
223+
}
224+
225+
public long getMinBatchTimeout() {
226+
return minBatchTimeout;
227+
}
228+
229+
public Model setMinBatchTimeout(long minBatchTimeout) {
230+
this.minBatchTimeout = minBatchTimeout;
231+
return this;
199232
}
200233

201234
/**
@@ -234,6 +267,58 @@ protected List<byte[]> getModelSetCommandBytes(String key) {
234267
return args;
235268
}
236269

270+
/**
271+
* Encodes the current model properties into an AI.MODELSTORE command to store in RedisAI Server.
272+
*
273+
* @param key
274+
* @return
275+
*/
276+
protected List<byte[]> getModelStoreCommandArgs(String key) {
277+
278+
List<byte[]> args = new ArrayList<>();
279+
args.add(SafeEncoder.encode(key));
280+
281+
args.add(backend.getRaw());
282+
args.add(device.getRaw());
283+
284+
if (tag != null) {
285+
args.add(Keyword.TAG.getRaw());
286+
args.add(SafeEncoder.encode(tag));
287+
}
288+
289+
if (batchSize > 0) {
290+
args.add(Keyword.BATCHSIZE.getRaw());
291+
args.add(Protocol.toByteArray(batchSize));
292+
293+
args.add(Keyword.MINBATCHSIZE.getRaw());
294+
args.add(Protocol.toByteArray(minBatchSize));
295+
296+
args.add(Keyword.MINBATCHTIMEOUT.getRaw());
297+
args.add(Protocol.toByteArray(minBatchTimeout));
298+
}
299+
300+
if (inputs != null && inputs.length > 0) {
301+
args.add(Keyword.INPUTS.getRaw());
302+
args.add(Protocol.toByteArray(inputs.length));
303+
for (String input : inputs) {
304+
args.add(SafeEncoder.encode(input));
305+
}
306+
}
307+
308+
if (outputs != null && outputs.length > 0) {
309+
args.add(Keyword.OUTPUTS.getRaw());
310+
args.add(Protocol.toByteArray(outputs.length));
311+
for (String output : outputs) {
312+
args.add(SafeEncoder.encode(output));
313+
}
314+
}
315+
316+
args.add(Keyword.BLOB.getRaw());
317+
args.add(blob);
318+
319+
return args;
320+
}
321+
237322
protected static List<byte[]> modelRunFlatArgs(
238323
String key, String[] inputs, String[] outputs, boolean includeCommandName) {
239324
List<byte[]> args = new ArrayList<>();
@@ -253,4 +338,32 @@ protected static List<byte[]> modelRunFlatArgs(
253338
}
254339
return args;
255340
}
341+
342+
protected static List<byte[]> modelExecuteCommandArgs(
343+
String key, String[] inputs, String[] outputs, long timeout, boolean includeCommandName) {
344+
345+
List<byte[]> args = new ArrayList<>();
346+
if (includeCommandName) {
347+
args.add(Command.MODEL_EXECUTE.getRaw());
348+
}
349+
args.add(SafeEncoder.encode(key));
350+
351+
args.add(Keyword.INPUTS.getRaw());
352+
args.add(Protocol.toByteArray(inputs.length));
353+
for (String input : inputs) {
354+
args.add(SafeEncoder.encode(input));
355+
}
356+
357+
args.add(Keyword.OUTPUTS.getRaw());
358+
args.add(Protocol.toByteArray(outputs.length));
359+
for (String output : outputs) {
360+
args.add(SafeEncoder.encode(output));
361+
}
362+
363+
if (timeout >= 0) {
364+
args.add(Keyword.TIMEOUT.getRaw());
365+
args.add(Protocol.toByteArray(timeout));
366+
}
367+
return args;
368+
}
256369
}

0 commit comments

Comments
 (0)