Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion redisai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,44 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
'OK'
"""
args = builder.loadbackend(identifier, path)
res = self.execute_command(*args)
res = self.execute_command(args)
return res if not self.enable_postprocess else processor.loadbackend(res)

def config(self, name: str, value: Union[str, int, None] = None) -> str:
    """
    Get/Set a configuration item. Currently available configurations are:
    BACKENDSPATH and MODEL_CHUNK_SIZE.
    For more details, see: https://oss.redis.com/redisai/master/commands/#aiconfig.
    If value is given - the configuration under name will be overridden.

    Parameters
    ----------
    name: str
        RedisAI config item to retrieve/override (BACKENDSPATH / MODEL_CHUNK_SIZE).
    value: Union[str, int]
        Value to set the config item with (if given).

    Returns
    -------
    The current configuration value if value is None,
    'OK' if value was given and overriding the configuration succeeded,
    raise an exception otherwise


    Example
    -------
    >>> con.config('MODEL_CHUNK_SIZE', 128 * 1024)
    'OK'
    >>> con.config('BACKENDSPATH', '/my/backends/path')
    'OK'
    >>> con.config('BACKENDSPATH')
    '/my/backends/path'
    >>> con.config('MODEL_CHUNK_SIZE')
    '131072'
    """
    args = builder.config(name, value)
    res = self.execute_command(args)
    # Only decode raw bytes replies (a GET result); non-bytes replies
    # (e.g. the 'OK' status) are passed through unchanged.
    return res if not self.enable_postprocess or not isinstance(res, bytes) else processor.config(res)

def modelstore(
self,
key: AnyStr,
Expand Down Expand Up @@ -209,6 +244,7 @@ def modelstore(
... inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
'OK'
"""
chunk_size = self.config('MODEL_CHUNK_SIZE')
args = builder.modelstore(
key,
backend,
Expand All @@ -220,6 +256,7 @@ def modelstore(
tag,
inputs,
outputs,
chunk_size=chunk_size
)
res = self.execute_command(*args)
return res if not self.enable_postprocess else processor.modelstore(res)
Expand Down
11 changes: 8 additions & 3 deletions redisai/command_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@


def loadbackend(identifier: AnyStr, path: AnyStr) -> Sequence:
    """Build the AI.CONFIG LOADBACKEND command string.

    Parameters
    ----------
    identifier: AnyStr
        Backend name (e.g. TF, TORCH, ONNX).
    path: AnyStr
        Path to the backend's shared library on the RedisAI server.

    Returns
    -------
    The full command as a single string, as expected by
    ``Client.execute_command``.
    """
    # The span held both the pre- and post-change diff lines; keep only the
    # single-string form, matching the client's execute_command(args) call.
    return f'AI.CONFIG LOADBACKEND {identifier} {path}'


def config(name: str, value: Union[str, int, None] = None) -> Sequence:
    """Build an AI.CONFIG command string.

    When *value* is omitted, a GET request for *name* is built; otherwise
    a set request assigning *value* to *name* is built.
    """
    if value is None:
        return f'AI.CONFIG GET {name}'
    return f'AI.CONFIG {name} {value}'


def modelstore(
Expand All @@ -22,6 +28,7 @@ def modelstore(
tag: AnyStr,
inputs: Union[AnyStr, List[AnyStr]],
outputs: Union[AnyStr, List[AnyStr]],
chunk_size: int = 500 * 1024 * 1024
) -> Sequence:
if name is None:
raise ValueError("Model name was not given")
Expand Down Expand Up @@ -66,9 +73,7 @@ def modelstore(
raise ValueError(
"Inputs and outputs keywords should not be specified for this backend"
)
chunk_size = 500 * 1024 * 1024 # TODO: this should be configurable.
data_chunks = [data[i: i + chunk_size] for i in range(0, len(data), chunk_size)]
# TODO: need a test case for this
args += ["BLOB", *data_chunks]
return args

Expand Down
33 changes: 17 additions & 16 deletions redisai/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,28 @@ def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
rai_result = utils.list2dict(res)
if meta_only is True:
return rai_result
elif as_numpy_mutable is True:
if as_numpy_mutable is True:
return utils.blob2numpy(
rai_result["blob"],
rai_result["shape"],
rai_result["dtype"],
mutable=True,
)
elif as_numpy is True:
if as_numpy is True:
return utils.blob2numpy(
rai_result["blob"],
rai_result["shape"],
rai_result["dtype"],
mutable=False,
)

if rai_result["dtype"] == "STRING":
def target(b):
return b.decode()
else:
if rai_result["dtype"] == "STRING":
def target(b):
return b.decode()
else:
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
utils.recursive_bytetransform(rai_result["values"], target)
return rai_result
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
utils.recursive_bytetransform(rai_result["values"], target)
return rai_result

@staticmethod
def scriptget(res):
Expand All @@ -66,19 +66,20 @@ def infoget(res):
# These functions are only doing decoding on the output from redis
decoder = staticmethod(decoder)
decoding_functions = (
"config",
"inforeset",
"loadbackend",
"modelstore",
"modelset",
"modeldel",
"modelexecute",
"modelrun",
"tensorset",
"scriptset",
"scriptstore",
"modelset",
"modelstore",
"scriptdel",
"scriptrun",
"scriptexecute",
"inforeset",
"scriptrun",
"scriptset",
"scriptstore",
"tensorset",
)
for fn in decoding_functions:
setattr(Processor, fn, decoder)
Loading