|
5 | 5 |
|
6 | 6 | from redisai import command_builder as builder
|
7 | 7 | from redisai.postprocessor import Processor
|
| 8 | +from deprecated import deprecated |
| 9 | +import warnings |
8 | 10 |
|
9 | 11 | processor = Processor()
|
10 | 12 |
|
11 | 13 |
|
12 | 14 | class Dag:
|
13 |
| - def __init__(self, load, persist, executor, readonly=False, postprocess=True): |
| 15 | + def __init__(self, load, persist, routing, timeout, executor, readonly=False): |
14 | 16 | self.result_processors = []
|
15 |
| - self.enable_postprocess = postprocess |
16 |
| - if readonly: |
17 |
| - if persist: |
18 |
| - raise RuntimeError( |
19 |
| - "READONLY requests cannot write (duh!) and should not " |
20 |
| - "have PERSISTing values" |
21 |
| - ) |
22 |
| - self.commands = ["AI.DAGRUN_RO"] |
| 17 | + self.enable_postprocess = True |
| 18 | + self.deprecatedDagrunMode = load is None and persist is None and routing is None |
| 19 | + self.readonly = readonly |
| 20 | + self.executor = executor |
| 21 | + |
| 22 | + if readonly and persist: |
| 23 | + raise RuntimeError( |
| 24 | + "READONLY requests cannot write (duh!) and should not " |
| 25 | + "have PERSISTing values" |
| 26 | + ) |
| 27 | + |
| 28 | + if self.deprecatedDagrunMode: |
| 29 | + # Throw warning about using deprecated dagrun |
| 30 | + warnings.warn("Creating Dag without any of LOAD, PERSIST and ROUTING arguments" |
| 31 | + "is allowed only in deprecated AI.DAGRUN or AI.DAGRUN_RO commands", DeprecationWarning) |
| 32 | + # Use dagrun |
| 33 | + if readonly: |
| 34 | + self.commands = ["AI.DAGRUN_RO"] |
| 35 | + else: |
| 36 | + self.commands = ["AI.DAGRUN"] |
23 | 37 | else:
|
24 |
| - self.commands = ["AI.DAGRUN"] |
25 |
| - if load: |
| 38 | + # Use dagexecute |
| 39 | + if readonly: |
| 40 | + self.commands = ["AI.DAGEXECUTE_RO"] |
| 41 | + else: |
| 42 | + self.commands = ["AI.DAGEXECUTE"] |
| 43 | + if load is not None: |
26 | 44 | if not isinstance(load, (list, tuple)):
|
27 | 45 | self.commands += ["LOAD", 1, load]
|
28 | 46 | else:
|
29 | 47 | self.commands += ["LOAD", len(load), *load]
|
30 |
| - if persist: |
| 48 | + if persist is not None: |
31 | 49 | if not isinstance(persist, (list, tuple)):
|
32 |
| - self.commands += ["PERSIST", 1, persist, "|>"] |
| 50 | + self.commands += ["PERSIST", 1, persist] |
33 | 51 | else:
|
34 |
| - self.commands += ["PERSIST", len(persist), *persist, "|>"] |
35 |
| - else: |
36 |
| - self.commands.append("|>") |
37 |
| - self.executor = executor |
| 52 | + self.commands += ["PERSIST", len(persist), *persist] |
| 53 | + if routing is not None: |
| 54 | + self.commands += ["ROUTING", routing] |
| 55 | + if timeout is not None: |
| 56 | + self.commands += ["TIMEOUT", timeout] |
| 57 | + |
| 58 | + self.commands.append("|>") |
38 | 59 |
|
39 | 60 | def tensorset(
|
40 | 61 | self,
|
@@ -69,20 +90,71 @@ def tensorget(
|
69 | 90 | )
|
70 | 91 | return self
|
71 | 92 |
|
| 93 | + @deprecated(version="1.2.0", reason="Use modelexecute instead") |
72 | 94 | def modelrun(
|
| 95 | + self, |
| 96 | + key: AnyStr, |
| 97 | + inputs: Union[AnyStr, List[AnyStr]], |
| 98 | + outputs: Union[AnyStr, List[AnyStr]], |
| 99 | + ) -> Any: |
| 100 | + if self.deprecatedDagrunMode: |
| 101 | + args = builder.modelrun(key, inputs, outputs) |
| 102 | + self.commands.extend(args) |
| 103 | + self.commands.append("|>") |
| 104 | + self.result_processors.append(bytes.decode) |
| 105 | + return self |
| 106 | + else: |
| 107 | + return self.modelexecute(key, inputs, outputs) |
| 108 | + |
| 109 | + def modelexecute( |
73 | 110 | self,
|
74 | 111 | key: AnyStr,
|
75 | 112 | inputs: Union[AnyStr, List[AnyStr]],
|
76 | 113 | outputs: Union[AnyStr, List[AnyStr]],
|
77 | 114 | ) -> Any:
|
78 |
| - args = builder.modelrun(key, inputs, outputs) |
| 115 | + if self.deprecatedDagrunMode: |
| 116 | + raise RuntimeError( |
| 117 | + "You are using deprecated version of DAG, that does not supports MODELEXECUTE." |
| 118 | + "The new version requires giving at least one of LOAD, PERSIST and ROUTING" |
| 119 | + "arguments when constructing the Dag" |
| 120 | + ) |
| 121 | + args = builder.modelexecute(key, inputs, outputs, None) |
79 | 122 | self.commands.extend(args)
|
80 | 123 | self.commands.append("|>")
|
81 | 124 | self.result_processors.append(bytes.decode)
|
82 | 125 | return self
|
83 | 126 |
|
| 127 | + def scriptexecute( |
| 128 | + self, |
| 129 | + key: AnyStr, |
| 130 | + function: str, |
| 131 | + keys: Union[AnyStr, Sequence[AnyStr]] = None, |
| 132 | + inputs: Union[AnyStr, Sequence[AnyStr]] = None, |
| 133 | + args: Union[AnyStr, Sequence[AnyStr]] = None, |
| 134 | + outputs: Union[AnyStr, List[AnyStr]] = None, |
| 135 | + ) -> Any: |
| 136 | + if self.readonly: |
| 137 | + raise RuntimeError( |
| 138 | + "AI.SCRIPTEXECUTE cannot be used in readonly mode" |
| 139 | + ) |
| 140 | + if self.deprecatedDagrunMode: |
| 141 | + raise RuntimeError( |
| 142 | + "You are using deprecated version of DAG, that does not supports SCRIPTEXECUTE." |
| 143 | + "The new version requires giving at least one of LOAD, PERSIST and ROUTING" |
| 144 | + "arguments when constructing the Dag" |
| 145 | + ) |
| 146 | + args = builder.scriptexecute(key, function, keys, inputs, args, outputs, None) |
| 147 | + self.commands.extend(args) |
| 148 | + self.commands.append("|>") |
| 149 | + self.result_processors.append(bytes.decode) |
| 150 | + return self |
| 151 | + |
| 152 | + @deprecated(version="1.2.0", reason="Use execute instead") |
84 | 153 | def run(self):
|
85 |
| - commands = self.commands[:-1] # removing the last "|> |
| 154 | + return self.execute() |
| 155 | + |
| 156 | + def execute(self): |
| 157 | + commands = self.commands[:-1] # removing the last "|>" |
86 | 158 | results = self.executor(*commands)
|
87 | 159 | if self.enable_postprocess:
|
88 | 160 | out = []
|
|
0 commit comments