Skip to content

Commit b6b9871

Browse files
authored
New Script commands (#50)
* Modify Script * SCRIPTSTORE * SCRIPTGET * SCRIPTEXECUTE * dag * missed args
1 parent 150ee12 commit b6b9871

File tree

10 files changed

+292
-47
lines changed

10 files changed

+292
-47
lines changed

src/main/java/com/redislabs/redisai/Command.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ public enum Command implements ProtocolCommand {
1313
MODEL_RUN("AI.MODELRUN"),
1414
MODEL_EXECUTE("AI.MODELEXECUTE"),
1515
SCRIPT_SET("AI.SCRIPTSET"),
16+
SCRIPT_STORE("AI.SCRIPTSTORE"),
1617
SCRIPT_GET("AI.SCRIPTGET"),
1718
SCRIPT_DEL("AI.SCRIPTDEL"),
1819
SCRIPT_RUN("AI.SCRIPTRUN"),
20+
SCRIPT_EXECUTE("AI.SCRIPTEXECUTE"),
1921
DAGRUN("AI.DAGRUN"),
2022
DAGRUN_RO("AI.DAGRUN_RO"),
2123
DAGEXECUTE("AI.DAGEXECUTE"),

src/main/java/com/redislabs/redisai/Dag.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ public Dag runScript(String key, String function, String[] inputs, String[] outp
6868
return this;
6969
}
7070

71+
@Override
72+
public Dag executeScript(
73+
String key,
74+
String function,
75+
List<String> keys,
76+
List<String> inputs,
77+
List<String> args,
78+
List<String> outputs) {
79+
List<byte[]> binary =
80+
Script.scriptExecuteFlatArgs(key, function, keys, inputs, args, outputs, -1L, true);
81+
this.commands.add(binary);
82+
this.tensorgetflag.add(false);
83+
return this;
84+
}
85+
7186
List<byte[]> dagRunFlatArgs(String[] loadKeys, String[] persistKeys) {
7287
List<byte[]> args = new ArrayList<>();
7388
if (loadKeys != null && loadKeys.length > 0) {

src/main/java/com/redislabs/redisai/DagRunCommands.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.redislabs.redisai;
22

3+
import java.util.List;
4+
35
interface DagRunCommands<T> {
46
T setTensor(String key, Tensor tensor);
57

@@ -10,4 +12,12 @@ interface DagRunCommands<T> {
1012
T executeModel(String key, String[] inputs, String[] outputs);
1113

1214
T runScript(String key, String function, String[] inputs, String[] outputs);
15+
16+
T executeScript(
17+
String key,
18+
String function,
19+
List<String> keys,
20+
List<String> inputs,
21+
List<String> args,
22+
List<String> outputs);
1323
}

src/main/java/com/redislabs/redisai/Device.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ public enum Device implements ProtocolCommand {
1010
private final byte[] raw;
1111

1212
Device() {
13-
raw = SafeEncoder.encode(this.name());
13+
raw = SafeEncoder.encode(name());
1414
}
1515

16+
@Override
1617
public byte[] getRaw() {
1718
return raw;
1819
}

src/main/java/com/redislabs/redisai/Keyword.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public enum Keyword implements ProtocolCommand {
1212
SOURCE,
1313
RESETSTAT,
1414
TAG,
15+
ENTRY_POINTS,
1516
BATCHSIZE,
1617
MINBATCHSIZE,
1718
MINBATCHTIMEOUT,
@@ -21,6 +22,7 @@ public enum Keyword implements ProtocolCommand {
2122
LOAD,
2223
PERSIST,
2324
KEYS,
25+
ARGS,
2426
PIPE("|>");
2527

2628
private final byte[] raw;

src/main/java/com/redislabs/redisai/RedisAI.java

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Map;
1010
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
1111
import redis.clients.jedis.BinaryClient;
12+
import redis.clients.jedis.Client;
1213
import redis.clients.jedis.HostAndPort;
1314
import redis.clients.jedis.Jedis;
1415
import redis.clients.jedis.JedisClientConfig;
@@ -103,6 +104,22 @@ private static JedisPoolConfig initPoolConfig(int poolSize) {
103104
return conf;
104105
}
105106

107+
private Jedis getConnection() {
108+
return pool.getResource();
109+
}
110+
111+
private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
112+
BinaryClient client = conn.getClient();
113+
client.sendCommand(command, args);
114+
return client;
115+
}
116+
117+
private Client sendCommand(Jedis conn, Command command, String... args) {
118+
Client client = conn.getClient();
119+
client.sendCommand(command, args);
120+
return client;
121+
}
122+
106123
/**
107124
* Direct mapping to AI.TENSORSET
108125
*
@@ -320,6 +337,27 @@ public boolean setScript(String key, Script script) {
320337
}
321338
}
322339

340+
/**
341+
* Direct mapping to AI.MODELSTORE command.
342+
*
343+
* <p>{@code AI.SCRIPTSTORE <key> <device> [TAG tag] ENTRY_POINTS <entry_point_count>
344+
* <entry_point> [<entry_point>...] SOURCE "<script>"}
345+
*
346+
* @param key name of key to store the Script in RedisAI server
347+
* @param script the Script Object
348+
* @return true if Script was properly stored in RedisAI server
349+
*/
350+
public boolean storeScript(String key, Script script) {
351+
try (Jedis conn = getConnection()) {
352+
List<String> args = script.getScriptStoreCommandBytes(key);
353+
return sendCommand(conn, Command.SCRIPT_STORE, args.toArray(new String[args.size()]))
354+
.getStatusCodeReply()
355+
.equals("OK");
356+
} catch (JedisDataException ex) {
357+
throw new RedisAIException(ex.getMessage(), ex);
358+
}
359+
}
360+
323361
/**
324362
* Direct mapping to AI.SCRIPTGET
325363
*
@@ -427,6 +465,50 @@ public boolean runScript(String key, String function, String[] inputs, String[]
427465
}
428466
}
429467

468+
public boolean executeScript(
469+
String key,
470+
String function,
471+
List<String> keys,
472+
List<String> inputs,
473+
List<String> args,
474+
List<String> outputs) {
475+
return executeScript(key, function, keys, inputs, args, outputs, -1);
476+
}
477+
478+
/**
479+
* Direct mapping to AI.SCRIPTEXECUTE command.
480+
*
481+
* <p>{@code AI.SCRIPTEXECUTE <key> <function> [KEYS n <key> [keys...]] [INPUTS m <input> [input
482+
* ...]] [ARGS k <arg> [arg...]] [OUTPUTS k <output> [output ...] [TIMEOUT t]]+}
483+
*
484+
* @param key
485+
* @param function
486+
* @param keys
487+
* @param inputs
488+
* @param args
489+
* @param outputs
490+
* @param timeout timeout in ms
491+
* @return
492+
*/
493+
public boolean executeScript(
494+
String key,
495+
String function,
496+
List<String> keys,
497+
List<String> inputs,
498+
List<String> args,
499+
List<String> outputs,
500+
long timeout) {
501+
try (Jedis conn = getConnection()) {
502+
List<byte[]> binary =
503+
Script.scriptExecuteFlatArgs(key, function, keys, inputs, args, outputs, timeout, false);
504+
return sendCommand(conn, Command.SCRIPT_EXECUTE, binary.toArray(new byte[binary.size()][]))
505+
.getStatusCodeReply()
506+
.equals("OK");
507+
} catch (JedisDataException ex) {
508+
throw new RedisAIException(ex.getMessage(), ex);
509+
}
510+
}
511+
430512
/**
431513
* Direct mapping to AI.DAGRUN specifies a direct acyclic graph of operations to run within
432514
* RedisAI
@@ -555,16 +637,6 @@ public boolean resetStat(String key) {
555637
}
556638
}
557639

558-
private Jedis getConnection() {
559-
return pool.getResource();
560-
}
561-
562-
private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
563-
BinaryClient client = conn.getClient();
564-
client.sendCommand(command, args);
565-
return client;
566-
}
567-
568640
/**
569641
* AI.CONFIG <BACKENDSPATH <path>>
570642
*

0 commit comments

Comments
 (0)