Skip to content

Commit

Permalink
feat(hubio): allow uses_ override args in executors from_hub (#4046)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Dec 8, 2021
1 parent 88969ae commit 7574e99
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 21 deletions.
24 changes: 21 additions & 3 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import os
from types import SimpleNamespace
from typing import Dict, Optional, Type
from typing import Dict, Optional, Type, Any

from .decorators import store_init_kwargs, wrap_func
from .. import __default_endpoint__, __args_executor_init__
Expand Down Expand Up @@ -231,10 +231,22 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@classmethod
def from_hub(cls: Type[T], uri: str, **kwargs) -> T:
def from_hub(
cls: Type[T],
uri: str,
context: Optional[Dict[str, Any]] = None,
uses_with: Optional[Dict] = None,
uses_metas: Optional[Dict] = None,
uses_requests: Optional[Dict] = None,
**kwargs,
) -> T:
"""Construct an Executor from Hub.
:param uri: a hub Executor scheme starts with `jinahub://`
:param context: context replacement variables in a dict, the value of the dict is the replacement.
:param uses_with: dictionary of parameters to overwrite from the default config's with field
:param uses_metas: dictionary of parameters to overwrite from the default config's metas field
:param uses_requests: dictionary of parameters to overwrite from the default config's requests field
:param kwargs: other kwargs accepted by the CLI ``jina hub pull``
:return: the Hub Executor object.
"""
Expand All @@ -257,4 +269,10 @@ def from_hub(cls: Type[T], uri: str, **kwargs) -> T:
f'Can not construct a native Executor from {uri}. Looks like you want to use it as a '
f'Docker container, you may want to use it in the Flow via `.add(uses={uri})` instead.'
)
return cls.load_config(_source)
return cls.load_config(
_source,
context=context,
uses_with=uses_with,
uses_metas=uses_metas,
uses_requests=uses_requests,
)
24 changes: 12 additions & 12 deletions jina/jaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def load_config(
allow_py_modules: bool = True,
substitute: bool = True,
context: Optional[Dict[str, Any]] = None,
override_with: Optional[Dict] = None,
override_metas: Optional[Dict] = None,
override_requests: Optional[Dict] = None,
uses_with: Optional[Dict] = None,
uses_metas: Optional[Dict] = None,
uses_requests: Optional[Dict] = None,
extra_search_paths: Optional[List[str]] = None,
**kwargs,
) -> 'JAMLCompatible':
Expand Down Expand Up @@ -527,9 +527,9 @@ def load_config(
:param allow_py_modules: allow importing plugins specified by ``py_modules`` in YAML at any levels
:param substitute: substitute environment, internal reference and context variables.
:param context: context replacement variables in a dict, the value of the dict is the replacement.
:param override_with: dictionary of parameters to overwrite from the default config's with field
:param override_metas: dictionary of parameters to overwrite from the default config's metas field
:param override_requests: dictionary of parameters to overwrite from the default config's requests field
:param uses_with: dictionary of parameters to overwrite from the default config's with field
:param uses_metas: dictionary of parameters to overwrite from the default config's metas field
:param uses_requests: dictionary of parameters to overwrite from the default config's requests field
:param extra_search_paths: extra paths used when looking for executor yaml files
:param kwargs: kwargs for parse_config_source
:return: :class:`JAMLCompatible` object
Expand Down Expand Up @@ -560,15 +560,15 @@ def _delitem(
if isinstance(v, dict):
_delitem(v, key)

if override_with is not None:
if uses_with is not None:
_delitem(no_tag_yml, key='uses_with')
if override_metas is not None:
if uses_metas is not None:
_delitem(no_tag_yml, key='uses_metas')
if override_requests is not None:
if uses_requests is not None:
_delitem(no_tag_yml, key='uses_requests')
cls._override_yml_params(no_tag_yml, 'with', override_with)
cls._override_yml_params(no_tag_yml, 'metas', override_metas)
cls._override_yml_params(no_tag_yml, 'requests', override_requests)
cls._override_yml_params(no_tag_yml, 'with', uses_with)
cls._override_yml_params(no_tag_yml, 'metas', uses_metas)
cls._override_yml_params(no_tag_yml, 'requests', uses_requests)

else:
raise BadConfigSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def _load_executor(self):
try:
self._executor = BaseExecutor.load_config(
self.args.uses,
override_with=self.args.uses_with,
override_metas=self.args.uses_metas,
override_requests=self.args.uses_requests,
uses_with=self.args.uses_with,
uses_metas=self.args.uses_metas,
uses_requests=self.args.uses_requests,
runtime_args=vars(self.args),
extra_search_paths=self.args.extra_search_paths,
)
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@


def test_executor_load_from_hub():
exec = Executor.from_hub('jinahub://DummyHubExecutor')
exec = Executor.from_hub(
'jinahub://DummyHubExecutor', uses_metas={'name': 'hello123'}
)
da = DocumentArray([Document()])
exec.foo(da)
assert da.texts == ['hello']
assert exec.metas.name == 'hello123'


def test_executor_import_with_external_dependencies(capsys):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/flow-construct/test_flow_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_flow_yaml_override_with_protocol():
path = os.path.join(cur_dir.parent, 'yaml/examples/faiss/flow-index.yml')
f1 = Flow.load_config(path)
assert f1.protocol == GatewayProtocolType.GRPC
f2 = Flow.load_config(path, override_with={'protocol': 'http'})
f2 = Flow.load_config(path, uses_with={'protocol': 'http'})
assert f2.protocol == GatewayProtocolType.HTTP
f3 = Flow.load_config(path, override_with={'protocol': 'websocket'})
f3 = Flow.load_config(path, uses_with={'protocol': 'websocket'})
assert f3.protocol == GatewayProtocolType.WEBSOCKET

0 comments on commit 7574e99

Please sign in to comment.