Skip to content

Commit

Permalink
Run pre/post processing in threadpool (#7327)
Browse files Browse the repository at this point in the history
* Add code

* Add code

* Add code

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot authored Feb 15, 2024
1 parent 7b84bc4 commit fb1f6be
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/hot-taxis-jump.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Run pre/post processing in threadpool
38 changes: 29 additions & 9 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,9 @@ def handle_streaming_diffs(

return data

def run_fn_batch(self, fn, batch, fn_index, state):
return [fn(fn_index, list(i), state) for i in zip(*batch)]

async def process_api(
self,
fn_index: int,
Expand Down Expand Up @@ -1565,10 +1568,14 @@ async def process_api(
raise ValueError(
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
)

inputs = [
self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
]
inputs = await anyio.to_thread.run_sync(
self.run_fn_batch,
self.preprocess_data,
inputs,
fn_index,
state,
limiter=self.limiter,
)
result = await self.call_function(
fn_index,
list(zip(*inputs)),
Expand All @@ -1579,17 +1586,24 @@ async def process_api(
in_event_listener,
)
preds = result["prediction"]
data = [
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
]
data = await anyio.to_thread.run_sync(
self.run_fn_batch,
self.postprocess_data,
preds,
fn_index,
state,
limiter=self.limiter,
)
data = list(zip(*data))
is_generating, iterator = None, None
else:
old_iterator = iterator
if old_iterator:
inputs = []
else:
inputs = self.preprocess_data(fn_index, inputs, state)
inputs = await anyio.to_thread.run_sync(
self.preprocess_data, fn_index, inputs, state, limiter=self.limiter
)
was_generating = old_iterator is not None
result = await self.call_function(
fn_index,
Expand All @@ -1600,7 +1614,13 @@ async def process_api(
event_data,
in_event_listener,
)
data = self.postprocess_data(fn_index, result["prediction"], state)
data = await anyio.to_thread.run_sync(
self.postprocess_data,
fn_index, # type: ignore
result["prediction"],
state,
limiter=self.limiter,
)
is_generating, iterator = result["is_generating"], result["iterator"]
if is_generating or was_generating:
run = id(old_iterator) if was_generating else id(iterator)
Expand Down
11 changes: 8 additions & 3 deletions gradio/components/gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -165,7 +166,8 @@ def postprocess(
if value is None:
return GalleryData(root=[])
output = []
for img in value:

def _save(img):
url = None
caption = None
orig_name = None
Expand Down Expand Up @@ -194,11 +196,14 @@ def postprocess(
orig_name = img.name
else:
raise ValueError(f"Cannot process type as image: {type(img)}")
entry = GalleryImage(
return GalleryImage(
image=FileData(path=file_path, url=url, orig_name=orig_name),
caption=caption,
)
output.append(entry)

with ThreadPoolExecutor() as executor:
for o in executor.map(_save, value):
output.append(o)
return GalleryData(root=output)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def save_pil_to_cache(
temp_dir = Path(cache_dir) / hash_bytes(bytes_data)
temp_dir.mkdir(exist_ok=True, parents=True)
filename = str((temp_dir / f"{name}.{format}").resolve())
img.save(filename, pnginfo=get_pil_metadata(img))
(temp_dir / f"{name}.{format}").resolve().write_bytes(bytes_data)
return filename


Expand Down

0 comments on commit fb1f6be

Please sign in to comment.