Skip to content

Commit

Permalink
Extend Interface.from_pipeline() to support Transformers.js.py pipeli…
Browse files Browse the repository at this point in the history
…nes on Lite (gradio-app#8052)

* Extend Interface.from_pipeline() to support Transformers.js.py pipelines on Lite (wip: only object-detection in this commit)

* add changeset

* Add image-classification and image-segmentation

* Add zero-shot-image-classification and zero-shot-object-detection

* Add document-question-answering

* Add feature-extraction and fill-mask

* Add question-answering and summarization

* Fix an error message

* Add text2text-generation, text-classification, and text-generation

* Add translation andtranslation_xx_to_yy

* Add zero-shot-classification

* Add postprocess_takes_inputs to control the args passed to the postprocess function of each pipeline

* Add topk option to image-classification

* format_backend

* Add audio-classification, automatic-speech-recognition, and zero-shot-audio-classification

* Add image-to-text

* Add token-classification (with JSON component as an output. Is it correct?)

* Ignore import type failure of transformers_js_py

* Add image-feature-extraction

* Add image-to-image

* Add text-to-audio

* Add depth-estimation

* Remove `render=False`

* Reorder the if-blocks following the Transformers.js doc

* Update gradio/pipelines_utils.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/pipelines_utils.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Fix feature-extraction demo

* Fix demo title

* Add guides/08_gradio-clients-and-lite/gradio-lite-and-transformers-js.md without contents

* Rename guides/08_gradio-clients-and-lite/*.md to fix the order

* Use pipeline.model.config._name_or_path for the demo title instead of pipeline.model.config.model_type

* Fix normal Interface.from_pipeline to use pipeline.model.config.name_or_path as the demo title

* Write an article about Gradio-Lite and Transformers.js

* Update the doc

* tweaks

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people authored May 3, 2024
1 parent cfc272f commit 1435d1d
Show file tree
Hide file tree
Showing 9 changed files with 599 additions and 5 deletions.
6 changes: 6 additions & 0 deletions .changeset/hungry-icons-serve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/lite": minor
"gradio": minor
---

feat:Extend Interface.from_pipeline() to support Transformers.js.py pipelines on Lite
9 changes: 6 additions & 3 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from gradio_client.documentation import document

from gradio import Examples, utils
from gradio import Examples, utils, wasm_utils
from gradio.blocks import Blocks
from gradio.components import (
Button,
Expand All @@ -29,7 +29,7 @@
from gradio.exceptions import RenderError
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Accordion, Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
from gradio.pipelines import load_from_js_pipeline, load_from_pipeline
from gradio.themes import ThemeClass as Theme

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
Expand Down Expand Up @@ -85,7 +85,10 @@ def from_pipeline(
pipe = pipeline("image-classification")
gr.Interface.from_pipeline(pipe).launch()
"""
interface_info = load_from_pipeline(pipeline)
if wasm_utils.IS_WASM:
interface_info = load_from_js_pipeline(pipeline)
else:
interface_info = load_from_pipeline(pipeline)
kwargs = dict(interface_info, **kwargs)
interface = cls(**kwargs)
return interface
Expand Down
33 changes: 32 additions & 1 deletion gradio/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from gradio.pipelines_utils import (
handle_diffusers_pipeline,
handle_transformers_js_pipeline,
handle_transformers_pipeline,
)

Expand Down Expand Up @@ -77,9 +78,39 @@ def fn(*params):

# define the title/description of the Interface
interface_info["title"] = (
pipeline.model.__class__.__name__
pipeline.model.config.name_or_path
if str(type(pipeline).__module__).startswith("transformers.pipelines")
else pipeline.__class__.__name__
)

return interface_info


def load_from_js_pipeline(pipeline) -> dict:
if str(type(pipeline).__module__).startswith("transformers_js_py."):
pipeline_info = handle_transformers_js_pipeline(pipeline)
else:
raise ValueError("pipeline must be a transformers_js_py's pipeline")

async def fn(*params):
preprocess = pipeline_info["preprocess"]
postprocess = pipeline_info["postprocess"]
postprocess_takes_inputs = pipeline_info.get("postprocess_takes_inputs", False)

preprocessed_params = preprocess(*params) if preprocess else params
pipeline_output = await pipeline(*preprocessed_params)
postprocessed_output = (
postprocess(pipeline_output, *(params if postprocess_takes_inputs else ()))
if postprocess
else pipeline_output
)

return postprocessed_output

interface_info = {
"fn": fn,
"inputs": pipeline_info["inputs"],
"outputs": pipeline_info["outputs"],
"title": f"{pipeline.task} ({pipeline.model.config._name_or_path})",
}
return interface_info
Loading

0 comments on commit 1435d1d

Please sign in to comment.