8
8
9
9
/** Direct mapping to RedisAI Model */
10
10
public class Model {
11
+
11
12
private Backend backend ;
12
13
private Device device ;
13
14
private String [] inputs ;
@@ -16,6 +17,18 @@ public class Model {
16
17
private String tag ;
17
18
private long batchSize ;
18
19
private long minBatchSize ;
20
+ private long minBatchTimeout ;
21
+
22
/**
 * Creates a model from its three mandatory parts. The remaining properties (inputs, outputs, tag
 * and batching options) keep their default values until set through the corresponding setters.
 *
 * @param backend - the backend for the model. Can be one of TF, TFLITE, TORCH or ONNX
 * @param device - the device that will execute the model. Can be one of CPU or GPU
 * @param blob - the Protobuf-serialized model
 */
public Model(Backend backend, Device device, byte[] blob) {
  // The array is stored as given (no copy is made).
  this.blob = blob;
  this.device = device;
  this.backend = backend;
}
19
32
20
33
/**
21
34
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
@@ -63,13 +76,13 @@ public Model(
63
76
}
64
77
65
78
public static Model createModelFromRespReply (List <?> reply ) {
66
- Model model = null ;
67
79
Backend backend = null ;
68
80
Device device = null ;
69
81
String tag = null ;
70
82
byte [] blob = null ;
71
83
long batchsize = 0 ;
72
84
long minbatchsize = 0 ;
85
+ long minbatchtimeout = 0 ;
73
86
String [] inputs = new String [0 ];
74
87
String [] outputs = new String [0 ];
75
88
for (int i = 0 ; i < reply .size (); i += 2 ) {
@@ -101,6 +114,9 @@ public static Model createModelFromRespReply(List<?> reply) {
101
114
case "minbatchsize" :
102
115
minbatchsize = (Long ) reply .get (i + 1 );
103
116
break ;
117
+ case "minbatchtimeout" :
118
+ minbatchtimeout = (Long ) reply .get (i + 1 );
119
+ break ;
104
120
case "inputs" :
105
121
List <byte []> inputsEncoded = (List <byte []>) reply .get (i + 1 );
106
122
if (!inputsEncoded .isEmpty ()) {
@@ -123,23 +139,27 @@ public static Model createModelFromRespReply(List<?> reply) {
123
139
break ;
124
140
}
125
141
}
142
+
126
143
if (backend == null || device == null || blob == null ) {
127
144
throw new JRedisAIRunTimeException (
128
145
"AI.MODELGET reply did not contained all elements to build the model" );
129
146
}
130
- model = new Model (backend , device , inputs , outputs , blob , batchsize , minbatchsize );
131
- if (tag != null ) {
132
- model .setTag (tag );
133
- }
134
- return model ;
147
+ return new Model (backend , device , blob )
148
+ .setInputs (inputs )
149
+ .setOutputs (outputs )
150
+ .setBatchSize (batchsize )
151
+ .setMinBatchSize (minbatchsize )
152
+ .setMinBatchTimeout (minbatchtimeout )
153
+ .setTag (tag );
135
154
}
136
155
137
156
public String getTag () {
138
157
return tag ;
139
158
}
140
159
141
- public void setTag (String tag ) {
160
+ public Model setTag (String tag ) {
142
161
this .tag = tag ;
162
+ return this ;
143
163
}
144
164
145
165
public byte [] getBlob () {
@@ -154,16 +174,18 @@ public String[] getOutputs() {
154
174
return outputs ;
155
175
}
156
176
157
- public void setOutputs (String [] outputs ) {
177
+ public Model setOutputs (String [] outputs ) {
158
178
this .outputs = outputs ;
179
+ return this ;
159
180
}
160
181
161
182
public String [] getInputs () {
162
183
return inputs ;
163
184
}
164
185
165
- public void setInputs (String [] inputs ) {
186
+ public Model setInputs (String [] inputs ) {
166
187
this .inputs = inputs ;
188
+ return this ;
167
189
}
168
190
169
191
public Device getDevice () {
@@ -186,16 +208,27 @@ public long getBatchSize() {
186
208
return batchSize ;
187
209
}
188
210
189
- public void setBatchSize (long batchsize ) {
211
+ public Model setBatchSize (long batchsize ) {
190
212
this .batchSize = batchsize ;
213
+ return this ;
191
214
}
192
215
193
216
public long getMinBatchSize () {
194
217
return minBatchSize ;
195
218
}
196
219
197
- public void setMinBatchSize (long minbatchsize ) {
220
+ public Model setMinBatchSize (long minbatchsize ) {
198
221
this .minBatchSize = minbatchsize ;
222
+ return this ;
223
+ }
224
+
225
+ public long getMinBatchTimeout () {
226
+ return minBatchTimeout ;
227
+ }
228
+
229
+ public Model setMinBatchTimeout (long minBatchTimeout ) {
230
+ this .minBatchTimeout = minBatchTimeout ;
231
+ return this ;
199
232
}
200
233
201
234
/**
@@ -234,6 +267,58 @@ protected List<byte[]> getModelSetCommandBytes(String key) {
234
267
return args ;
235
268
}
236
269
270
+ /**
271
+ * Encodes the current model properties into an AI.MODELSTORE command to store in RedisAI Server.
272
+ *
273
+ * @param key
274
+ * @return
275
+ */
276
+ protected List <byte []> getModelStoreCommandArgs (String key ) {
277
+
278
+ List <byte []> args = new ArrayList <>();
279
+ args .add (SafeEncoder .encode (key ));
280
+
281
+ args .add (backend .getRaw ());
282
+ args .add (device .getRaw ());
283
+
284
+ if (tag != null ) {
285
+ args .add (Keyword .TAG .getRaw ());
286
+ args .add (SafeEncoder .encode (tag ));
287
+ }
288
+
289
+ if (batchSize > 0 ) {
290
+ args .add (Keyword .BATCHSIZE .getRaw ());
291
+ args .add (Protocol .toByteArray (batchSize ));
292
+
293
+ args .add (Keyword .MINBATCHSIZE .getRaw ());
294
+ args .add (Protocol .toByteArray (minBatchSize ));
295
+
296
+ args .add (Keyword .MINBATCHTIMEOUT .getRaw ());
297
+ args .add (Protocol .toByteArray (minBatchTimeout ));
298
+ }
299
+
300
+ if (inputs != null && inputs .length > 0 ) {
301
+ args .add (Keyword .INPUTS .getRaw ());
302
+ args .add (Protocol .toByteArray (inputs .length ));
303
+ for (String input : inputs ) {
304
+ args .add (SafeEncoder .encode (input ));
305
+ }
306
+ }
307
+
308
+ if (outputs != null && outputs .length > 0 ) {
309
+ args .add (Keyword .OUTPUTS .getRaw ());
310
+ args .add (Protocol .toByteArray (outputs .length ));
311
+ for (String output : outputs ) {
312
+ args .add (SafeEncoder .encode (output ));
313
+ }
314
+ }
315
+
316
+ args .add (Keyword .BLOB .getRaw ());
317
+ args .add (blob );
318
+
319
+ return args ;
320
+ }
321
+
237
322
protected static List <byte []> modelRunFlatArgs (
238
323
String key , String [] inputs , String [] outputs , boolean includeCommandName ) {
239
324
List <byte []> args = new ArrayList <>();
@@ -253,4 +338,32 @@ protected static List<byte[]> modelRunFlatArgs(
253
338
}
254
339
return args ;
255
340
}
341
+
342
+ protected static List <byte []> modelExecuteCommandArgs (
343
+ String key , String [] inputs , String [] outputs , long timeout , boolean includeCommandName ) {
344
+
345
+ List <byte []> args = new ArrayList <>();
346
+ if (includeCommandName ) {
347
+ args .add (Command .MODEL_EXECUTE .getRaw ());
348
+ }
349
+ args .add (SafeEncoder .encode (key ));
350
+
351
+ args .add (Keyword .INPUTS .getRaw ());
352
+ args .add (Protocol .toByteArray (inputs .length ));
353
+ for (String input : inputs ) {
354
+ args .add (SafeEncoder .encode (input ));
355
+ }
356
+
357
+ args .add (Keyword .OUTPUTS .getRaw ());
358
+ args .add (Protocol .toByteArray (outputs .length ));
359
+ for (String output : outputs ) {
360
+ args .add (SafeEncoder .encode (output ));
361
+ }
362
+
363
+ if (timeout >= 0 ) {
364
+ args .add (Keyword .TIMEOUT .getRaw ());
365
+ args .add (Protocol .toByteArray (timeout ));
366
+ }
367
+ return args ;
368
+ }
256
369
}
0 commit comments