Skip to content

Update inference widget from adapter-transformers to Adapters #422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docker_images/adapter_transformers/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from app.pipelines import (
Pipeline,
QuestionAnsweringPipeline,
SummarizationPipeline,
TextClassificationPipeline,
TextGenerationPipeline,
TokenClassificationPipeline,
)
from starlette.applications import Starlette
Expand Down Expand Up @@ -39,7 +41,9 @@
# directories. Implement directly within the directories.
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
"question-answering": QuestionAnsweringPipeline,
"summarization": SummarizationPipeline,
"text-classification": TextClassificationPipeline,
"text-generation": TextGenerationPipeline,
"token-classification": TokenClassificationPipeline,
# IMPLEMENT_THIS: Add your implemented tasks here !
}
Expand Down
2 changes: 2 additions & 0 deletions docker_images/adapter_transformers/app/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from app.pipelines.base import Pipeline, PipelineException # isort:skip

from app.pipelines.question_answering import QuestionAnsweringPipeline
from app.pipelines.summarization import SummarizationPipeline
from app.pipelines.text_classification import TextClassificationPipeline
from app.pipelines.text_generation import TextGenerationPipeline
from app.pipelines.token_classification import TokenClassificationPipeline
14 changes: 12 additions & 2 deletions docker_images/adapter_transformers/app/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any

from transformers import AutoModelWithHeads, AutoTokenizer, get_adapter_info
from adapters import AutoAdapterModel, get_adapter_info
from transformers import AutoTokenizer
from transformers.pipelines.base import logger


class Pipeline(ABC):
Expand All @@ -20,8 +22,16 @@ def _load_pipeline_instance(pipeline_class, adapter_id):
raise ValueError(f"Adapter with id '{adapter_id}' not available.")

tokenizer = AutoTokenizer.from_pretrained(adapter_info.model_name)
model = AutoModelWithHeads.from_pretrained(adapter_info.model_name)
model = AutoAdapterModel.from_pretrained(adapter_info.model_name)
model.load_adapter(adapter_id, source="hf", set_active=True)

# Transformers incorrectly logs an error because class name is not known. Filter this out.
logger.addFilter(
lambda record: not record.getMessage().startswith(
f"The model '{model.__class__.__name__}' is not supported"
)
)

return pipeline_class(model=model, tokenizer=tokenizer)


Expand Down
20 changes: 20 additions & 0 deletions docker_images/adapter_transformers/app/pipelines/summarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Dict, List

from app.pipelines import Pipeline
from transformers import SummarizationPipeline as TransformersSummarizationPipeline


class SummarizationPipeline(Pipeline):
    """Summarization task backed by an AdapterHub adapter.

    Wraps the Transformers ``SummarizationPipeline`` around a model/adapter
    pair resolved by the shared base-class loader.
    """

    def __init__(self, adapter_id: str):
        # All model/tokenizer/adapter loading is delegated to the base class.
        self.pipeline = self._load_pipeline_instance(
            TransformersSummarizationPipeline, adapter_id
        )

    def __call__(self, inputs: str) -> List[Dict[str, str]]:
        """
        Args:
            inputs (:obj:`str`): a string to be summarized
        Return:
            A :obj:`list` of :obj:`dict` in the form of {"summary_text": "The string after summarization"}
        """
        # Truncate over-long inputs rather than erroring on sequence length.
        outputs = self.pipeline(inputs, truncation=True)
        return outputs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Dict, List

from app.pipelines import Pipeline
from transformers import TextGenerationPipeline as TransformersTextGenerationPipeline


class TextGenerationPipeline(Pipeline):
    """Text-generation task backed by an AdapterHub adapter.

    Wraps the Transformers ``TextGenerationPipeline`` around a model/adapter
    pair resolved by the shared base-class loader.
    """

    def __init__(self, adapter_id: str):
        # All model/tokenizer/adapter loading is delegated to the base class.
        self.pipeline = self._load_pipeline_instance(
            TransformersTextGenerationPipeline, adapter_id
        )

    def __call__(self, inputs: str) -> List[Dict[str, str]]:
        """
        Args:
            inputs (:obj:`str`):
                The input text
        Return:
            A :obj:`list`:. The list contains a single item that is a dict {"text": the model output}
        """
        # Truncate over-long inputs rather than erroring on sequence length.
        outputs = self.pipeline(inputs, truncation=True)
        return outputs
10 changes: 5 additions & 5 deletions docker_images/adapter_transformers/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
starlette==0.27.0
api-inference-community==0.0.23
torch==1.13.1
adapter-transformers==2.2.0
huggingface_hub==0.5.1
starlette==0.37.2
api-inference-community==0.0.32
torch==2.3.0
adapters==0.2.1
huggingface_hub==0.23.0
4 changes: 3 additions & 1 deletion docker_images/adapter_transformers/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
# Tests do not check the actual values of the model output, so small dummy
# models are recommended for faster tests.
TESTABLE_MODELS: Dict[str, str] = {
"question-answering": "calpt/adapter-bert-base-squad1",
"question-answering": "AdapterHub/roberta-base-pf-squad",
"summarization": "AdapterHub/facebook-bart-large_sum_xsum_pfeiffer",
"text-classification": "AdapterHub/roberta-base-pf-sick",
"text-generation": "AdapterHub/gpt2_lm_poem_pfeiffer",
"token-classification": "AdapterHub/roberta-base-pf-conll2003",
}

Expand Down
75 changes: 75 additions & 0 deletions docker_images/adapter_transformers/tests/test_api_summarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import os
from unittest import TestCase, skipIf

from app.main import ALLOWED_TASKS
from starlette.testclient import TestClient
from tests.test_api import TESTABLE_MODELS


@skipIf(
    "summarization" not in ALLOWED_TASKS,
    "summarization not implemented",
)
class SummarizationTestCase(TestCase):
    """End-to-end checks for the summarization endpoint of the app."""

    @classmethod
    def setUpClass(cls):
        # Drop any pipeline cached by a previously-run test module so this
        # suite builds its own pipeline from the env vars set in setUp.
        from app.main import get_pipeline

        get_pipeline.cache_clear()

    def setUp(self):
        # Remember the pre-test env so tearDown can restore it exactly.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = TESTABLE_MODELS["summarization"]
        os.environ["TASK"] = "summarization"
        from app.main import app

        self.app = app

    def tearDown(self):
        # Restore (or remove) each env var touched in setUp.
        for key, previous in (
            ("MODEL_ID", self.old_model_id),
            ("TASK", self.old_task),
        ):
            if previous is None:
                del os.environ[key]
            else:
                os.environ[key] = previous

    def _post_json(self, payload):
        # Helper: POST a JSON payload and return (status_code, decoded body).
        with TestClient(self.app) as client:
            response = client.post("/", json=payload)
        return response.status_code, json.loads(response.content)

    def test_simple(self):
        text = "The weather is nice today."

        # Both the wrapped-dict and the bare-string request shapes must work.
        for payload in ({"inputs": text}, text):
            status, content = self._post_json(payload)
            self.assertEqual(status, 200)
            self.assertEqual(type(content), list)
            self.assertEqual(type(content[0]["summary_text"]), str)

    def test_malformed_question(self):
        # Invalid UTF-8 body must be rejected with a structured 400 error.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(set(content.keys()), {"error"})
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import json
import os
from unittest import TestCase, skipIf

from app.main import ALLOWED_TASKS
from starlette.testclient import TestClient
from tests.test_api import TESTABLE_MODELS


@skipIf(
    "text-generation" not in ALLOWED_TASKS,
    "text-generation not implemented",
)
class TextGenerationTestCase(TestCase):
    """End-to-end checks for the text-generation endpoint of the app."""

    def setUp(self):
        model_id = TESTABLE_MODELS["text-generation"]
        # Remember the pre-test env so tearDown can restore it exactly.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = model_id
        os.environ["TASK"] = "text-generation"
        from app.main import app

        self.app = app

    @classmethod
    def setUpClass(cls):
        # Drop any pipeline cached by a previously-run test module so this
        # suite builds its own pipeline from the env vars set in setUp.
        from app.main import get_pipeline

        get_pipeline.cache_clear()

    def tearDown(self):
        # Restore (or remove) each env var touched in setUp.
        if self.old_model_id is not None:
            os.environ["MODEL_ID"] = self.old_model_id
        else:
            del os.environ["MODEL_ID"]
        if self.old_task is not None:
            os.environ["TASK"] = self.old_task
        else:
            del os.environ["TASK"]

    def test_simple(self):
        inputs = "The weather is nice today."

        with TestClient(self.app) as client:
            response = client.post("/", json={"inputs": inputs})

        self.assertEqual(
            response.status_code,
            200,
        )
        content = json.loads(response.content)
        self.assertEqual(type(content), list)
        # Fix: the wrapped-dict request previously only checked that the body
        # is a list; also validate the payload shape, matching the bare-string
        # request below and the sibling task test suites.
        self.assertEqual(type(content[0]["generated_text"]), str)

        with TestClient(self.app) as client:
            response = client.post("/", json=inputs)

        self.assertEqual(
            response.status_code,
            200,
        )
        content = json.loads(response.content)
        self.assertEqual(type(content), list)
        self.assertEqual(type(content[0]["generated_text"]), str)

    def test_malformed_question(self):
        # Invalid UTF-8 body must be rejected with a structured 400 error.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(
            response.status_code,
            400,
        )
        content = json.loads(response.content)
        self.assertEqual(set(content.keys()), {"error"})
14 changes: 13 additions & 1 deletion tests/test_dockers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ def test_adapter_transformers(self):
self.framework_docker_test(
"adapter_transformers",
"question-answering",
"calpt/adapter-bert-base-squad1",
"AdapterHub/roberta-base-pf-squad",
)

self.framework_docker_test(
"adapter_transformers",
"summarization",
"AdapterHub/facebook-bart-large_sum_xsum_pfeiffer",
)

self.framework_docker_test(
Expand All @@ -145,6 +151,12 @@ def test_adapter_transformers(self):
"AdapterHub/roberta-base-pf-sick",
)

self.framework_docker_test(
"adapter_transformers",
"text-generation",
"AdapterHub/gpt2_lm_poem_pfeiffer",
)

self.framework_docker_test(
"adapter_transformers",
"token-classification",
Expand Down
Loading