diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index 51845502f49a9..8344f4231a456 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -5,7 +5,7 @@ import inspect import os from abc import ABC -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union +from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union, Tuple from jina.excepts import BadClientInput from jina.helper import T, parse_client, send_telemetry_event, typename @@ -47,8 +47,6 @@ def __init__( # affect users os-level envs. os.unsetenv('http_proxy') os.unsetenv('https_proxy') - self._inputs = None - self._inputs_length = None self._setup_instrumentation( name=( self.args.name @@ -125,60 +123,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None: raise BadClientInput from ex def _get_requests( - self, **kwargs - ) -> Union[Iterator['Request'], AsyncIterator['Request']]: + self, inputs, **kwargs + ) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]: """ Get request in generator. + :param inputs: The inputs argument to get the requests from. :param kwargs: Keyword arguments. - :return: Iterator of request. + :return: Iterator of request and the length of the inputs. """ _kwargs = vars(self.args) - _kwargs['data'] = self.inputs + if hasattr(inputs, '__call__'): + inputs = inputs() + + _kwargs['data'] = inputs # override by the caller-specific kwargs _kwargs.update(kwargs) - if hasattr(self._inputs, '__len__'): - total_docs = len(self._inputs) + if hasattr(inputs, '__len__'): + total_docs = len(inputs) elif 'total_docs' in _kwargs: total_docs = _kwargs['total_docs'] else: total_docs = None if total_docs: - self._inputs_length = max(1, total_docs / _kwargs['request_size']) + inputs_length = max(1, total_docs / _kwargs['request_size']) + else: + inputs_length = None - if inspect.isasyncgen(self.inputs): + if inspect.isasyncgen(inputs): from jina.clients.request.asyncio import request_generator - return request_generator(**_kwargs) + return request_generator(**_kwargs), inputs_length else: from jina.clients.request import request_generator - return request_generator(**_kwargs) - - @property - def inputs(self) -> 'InputType': - """ - An iterator of bytes, each element represents a Document's raw content. - - ``inputs`` defined in the protobuf - - :return: inputs - """ - return self._inputs - - @inputs.setter - def inputs(self, bytes_gen: 'InputType') -> None: - """ - Set the input data. - - :param bytes_gen: input type - """ - if hasattr(bytes_gen, '__call__'): - self._inputs = bytes_gen() - else: - self._inputs = bytes_gen + return request_generator(**_kwargs), inputs_length @abc.abstractmethod async def _get_results( diff --git a/jina/clients/base/grpc.py b/jina/clients/base/grpc.py index 204924a57f74d..917950d05c4fd 100644 --- a/jina/clients/base/grpc.py +++ b/jina/clients/base/grpc.py @@ -90,8 +90,7 @@ async def _get_results( else grpc.Compression.NoCompression ) - self.inputs = inputs - req_iter = self._get_requests(**kwargs) + req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs) continue_on_error = self.continue_on_error # while loop with retries, check in which state the `iterator` remains after failure options = client_grpc_options( @@ -120,7 +119,7 @@ async def _get_results( self.logger.debug(f'connected to {self.args.host}:{self.args.port}') with ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) as p_bar: try: if stream: diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index c10cb40749e27..49cfa7461886f 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -153,15 +153,14 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) on = kwargs.get('on', '/post') if len(self._endpoints) == 0: await self._get_endpoints_from_openapi(**kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not self.show_progress + total_length=inputs_length, disable=not self.show_progress ) p_bar = stack.enter_context(cm1) proto = 'https' if self.args.tls else 'http' diff --git a/jina/clients/base/websocket.py b/jina/clients/base/websocket.py index a8b868704bac0..01d58b52609f6 100644 --- a/jina/clients/base/websocket.py +++ b/jina/clients/base/websocket.py @@ -108,12 +108,11 @@ async def _get_results( with ImportExtensions(required=True): pass - self.inputs = inputs - request_iterator = self._get_requests(**kwargs) + request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs) async with AsyncExitStack() as stack: cm1 = ProgressBar( - total_length=self._inputs_length, disable=not (self.show_progress) + total_length=inputs_length, disable=not (self.show_progress) ) p_bar = stack.enter_context(cm1)