Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix optim issue with optimization #6184

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ async def _sleep_then_set(self):
self._flush_trigger.set()
self._timer_finished = True

async def push(self, request: DataRequest) -> asyncio.Queue:
async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
"""Append request to the list of requests to be processed.

This method creates an asyncio Queue for that request and keeps track of it. It returns
this queue to the caller so that the caller can know when this request has been processed.

:param request: The request to append to the queue.
:param http: Flag indicating whether the request is served via HTTP, which enables some optimizations

:return: The queue that will receive an item when the request has been processed.
"""
Expand All @@ -103,7 +104,7 @@ async def push(self, request: DataRequest) -> asyncio.Queue:
self._start_timer()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush())
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
Expand All @@ -118,8 +119,10 @@ async def push(self, request: DataRequest) -> asyncio.Queue:

return queue

async def _await_then_flush(self) -> None:
"""Process all requests in the queue once flush_trigger event is set."""
async def _await_then_flush(self, http=False) -> None:
"""Process all requests in the queue once flush_trigger event is set.
:param http: Flag indicating whether the request is served via HTTP, which enables some optimizations
"""

def _get_docs_groups_completed_request_indexes(
non_assigned_docs,
Expand Down Expand Up @@ -200,9 +203,13 @@ async def _assign_results(
for docs_group, request_idx in zip(docs_grouped, completed_req_idxs):
request = self._requests[request_idx]
request_completed = self._requests_completed[request_idx]
request.data.set_docs_convert_arrays(
docs_group, ndarray_type=self._output_array_type
)
if http is False or self._output_array_type is not None:
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
request.data.set_docs_convert_arrays(
docs_group, ndarray_type=self._output_array_type
)
else:
request.direct_docs = docs_group
await request_completed.put(None)

return num_assigned_docs
Expand Down
3 changes: 2 additions & 1 deletion jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,9 @@ async def handle(
**self._batchqueue_config[exec_endpoint],
)
# This is necessary because push might need to await for the queue to be emptied
# the batch queue will change the request in-place
queue = await self._batchqueue_instances[exec_endpoint][param_key].push(
requests[0]
requests[0], http=http
)
item = await queue.get()
queue.task_done()
Expand Down
Loading
Loading