Skip to content

Commit bc65526

Browse files
authored
Update inference widget from adapter-transformers to Adapters (#422)
* Migration from Adapter-Transformers to Adapters
* Filter error message
* Rename back for now
* Update tests
* Fix errors in Adapters
* Update adapters package version
1 parent 48c0c4b commit bc65526

File tree

10 files changed

+229
-9
lines changed

10 files changed

+229
-9
lines changed

docker_images/adapter_transformers/app/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from app.pipelines import (
88
Pipeline,
99
QuestionAnsweringPipeline,
10+
SummarizationPipeline,
1011
TextClassificationPipeline,
12+
TextGenerationPipeline,
1113
TokenClassificationPipeline,
1214
)
1315
from starlette.applications import Starlette
@@ -39,7 +41,9 @@
3941
# directories. Implement directly within the directories.
4042
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
4143
"question-answering": QuestionAnsweringPipeline,
44+
"summarization": SummarizationPipeline,
4245
"text-classification": TextClassificationPipeline,
46+
"text-generation": TextGenerationPipeline,
4347
"token-classification": TokenClassificationPipeline,
4448
# IMPLEMENT_THIS: Add your implemented tasks here !
4549
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from app.pipelines.base import Pipeline, PipelineException # isort:skip
22

33
from app.pipelines.question_answering import QuestionAnsweringPipeline
4+
from app.pipelines.summarization import SummarizationPipeline
45
from app.pipelines.text_classification import TextClassificationPipeline
6+
from app.pipelines.text_generation import TextGenerationPipeline
57
from app.pipelines.token_classification import TokenClassificationPipeline

docker_images/adapter_transformers/app/pipelines/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from abc import ABC, abstractmethod
22
from typing import Any
33

4-
from transformers import AutoModelWithHeads, AutoTokenizer, get_adapter_info
4+
from adapters import AutoAdapterModel, get_adapter_info
5+
from transformers import AutoTokenizer
6+
from transformers.pipelines.base import logger
57

68

79
class Pipeline(ABC):
@@ -20,8 +22,16 @@ def _load_pipeline_instance(pipeline_class, adapter_id):
2022
raise ValueError(f"Adapter with id '{adapter_id}' not available.")
2123

2224
tokenizer = AutoTokenizer.from_pretrained(adapter_info.model_name)
23-
model = AutoModelWithHeads.from_pretrained(adapter_info.model_name)
25+
model = AutoAdapterModel.from_pretrained(adapter_info.model_name)
2426
model.load_adapter(adapter_id, source="hf", set_active=True)
27+
28+
# Transformers incorrectly logs an error because class name is not known. Filter this out.
29+
logger.addFilter(
30+
lambda record: not record.getMessage().startswith(
31+
f"The model '{model.__class__.__name__}' is not supported"
32+
)
33+
)
34+
2535
return pipeline_class(model=model, tokenizer=tokenizer)
2636

2737

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Dict, List
2+
3+
from app.pipelines import Pipeline
4+
from transformers import SummarizationPipeline as TransformersSummarizationPipeline
5+
6+
7+
class SummarizationPipeline(Pipeline):
    """Summarization over an AdapterHub adapter, backed by the Transformers pipeline."""

    def __init__(self, adapter_id: str):
        # Base-class helper resolves the adapter's backbone model, loads the
        # adapter, and wraps both in the given Transformers pipeline class.
        pipeline_cls = TransformersSummarizationPipeline
        self.pipeline = self._load_pipeline_instance(pipeline_cls, adapter_id)

    def __call__(self, inputs: str) -> List[Dict[str, str]]:
        """
        Args:
            inputs (:obj:`str`): the text to summarize
        Return:
            A :obj:`list` of :obj:`dict` of the form
            {"summary_text": "The string after summarization"}
        """
        # truncation=True keeps over-long inputs within the model's max length.
        result = self.pipeline(inputs, truncation=True)
        return result
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Dict, List
2+
3+
from app.pipelines import Pipeline
4+
from transformers import TextGenerationPipeline as TransformersTextGenerationPipeline
5+
6+
7+
class TextGenerationPipeline(Pipeline):
    """Text generation over an AdapterHub adapter, backed by the Transformers pipeline."""

    def __init__(self, adapter_id: str):
        # Base-class helper resolves the adapter's backbone model, loads the
        # adapter, and wraps both in the Transformers text-generation pipeline.
        self.pipeline = self._load_pipeline_instance(
            TransformersTextGenerationPipeline, adapter_id
        )

    def __call__(self, inputs: str) -> List[Dict[str, str]]:
        """
        Args:
            inputs (:obj:`str`):
                The input text
        Return:
            A :obj:`list`:. The list contains a single item that is a dict
            {"generated_text": the model output}
        """
        # FIX: docstring previously documented a "text" key; transformers'
        # TextGenerationPipeline returns "generated_text" (the API test suite
        # asserts exactly that key).
        # truncation=True keeps over-long inputs within the model's max length.
        return self.pipeline(inputs, truncation=True)
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
starlette==0.27.0
2-
api-inference-community==0.0.23
3-
torch==1.13.1
4-
adapter-transformers==2.2.0
5-
huggingface_hub==0.5.1
1+
starlette==0.37.2
2+
api-inference-community==0.0.32
3+
torch==2.3.0
4+
adapters==0.2.1
5+
huggingface_hub==0.23.0

docker_images/adapter_transformers/tests/test_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
# Tests do not check the actual values of the model output, so small dummy
1010
# models are recommended for faster tests.
1111
TESTABLE_MODELS: Dict[str, str] = {
12-
"question-answering": "calpt/adapter-bert-base-squad1",
12+
"question-answering": "AdapterHub/roberta-base-pf-squad",
13+
"summarization": "AdapterHub/facebook-bart-large_sum_xsum_pfeiffer",
1314
"text-classification": "AdapterHub/roberta-base-pf-sick",
15+
"text-generation": "AdapterHub/gpt2_lm_poem_pfeiffer",
1416
"token-classification": "AdapterHub/roberta-base-pf-conll2003",
1517
}
1618

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import json
2+
import os
3+
from unittest import TestCase, skipIf
4+
5+
from app.main import ALLOWED_TASKS
6+
from starlette.testclient import TestClient
7+
from tests.test_api import TESTABLE_MODELS
8+
9+
10+
@skipIf(
    "summarization" not in ALLOWED_TASKS,
    "summarization not implemented",
)
class SummarizationTestCase(TestCase):
    """API tests for the summarization task of the adapter_transformers image."""

    @classmethod
    def setUpClass(cls):
        from app.main import get_pipeline

        # Drop any pipeline cached by a previous test class.
        get_pipeline.cache_clear()

    def setUp(self):
        # Remember the previous environment so tearDown can restore it.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = TESTABLE_MODELS["summarization"]
        os.environ["TASK"] = "summarization"
        from app.main import app

        self.app = app

    def tearDown(self):
        # Restore (or clear) the env vars touched in setUp.
        for key, previous in (
            ("MODEL_ID", self.old_model_id),
            ("TASK", self.old_task),
        ):
            if previous is None:
                del os.environ[key]
            else:
                os.environ[key] = previous

    def test_simple(self):
        inputs = "The weather is nice today."

        # Both payload shapes (wrapped dict and bare string) must be accepted
        # and produce the same response schema.
        for payload in ({"inputs": inputs}, inputs):
            with TestClient(self.app) as client:
                response = client.post("/", json=payload)

            self.assertEqual(response.status_code, 200)
            content = json.loads(response.content)
            self.assertEqual(type(content), list)
            self.assertEqual(type(content[0]["summary_text"]), str)

    def test_malformed_question(self):
        # Invalid UTF-8 body must yield a 400 with an {"error": ...} payload.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(set(content.keys()), {"error"})
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
import os
3+
from unittest import TestCase, skipIf
4+
5+
from app.main import ALLOWED_TASKS
6+
from starlette.testclient import TestClient
7+
from tests.test_api import TESTABLE_MODELS
8+
9+
10+
@skipIf(
    "text-generation" not in ALLOWED_TASKS,
    "text-generation not implemented",
)
class TextGenerationTestCase(TestCase):
    """API tests for the text-generation task of the adapter_transformers image."""

    @classmethod
    def setUpClass(cls):
        from app.main import get_pipeline

        # Drop any pipeline cached by a previous test class.
        get_pipeline.cache_clear()

    def setUp(self):
        model_id = TESTABLE_MODELS["text-generation"]
        # Remember the previous environment so tearDown can restore it.
        self.old_model_id = os.getenv("MODEL_ID")
        self.old_task = os.getenv("TASK")
        os.environ["MODEL_ID"] = model_id
        os.environ["TASK"] = "text-generation"
        from app.main import app

        self.app = app

    def tearDown(self):
        # Restore (or clear) the env vars touched in setUp.
        if self.old_model_id is not None:
            os.environ["MODEL_ID"] = self.old_model_id
        else:
            del os.environ["MODEL_ID"]
        if self.old_task is not None:
            os.environ["TASK"] = self.old_task
        else:
            del os.environ["TASK"]

    def test_simple(self):
        inputs = "The weather is nice today."

        with TestClient(self.app) as client:
            response = client.post("/", json={"inputs": inputs})

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(type(content), list)
        # FIX: the dict-payload request previously skipped the output-key
        # check that the bare-string request (and SummarizationTestCase)
        # performs; assert the same schema for both payload shapes.
        self.assertEqual(type(content[0]["generated_text"]), str)

        with TestClient(self.app) as client:
            response = client.post("/", json=inputs)

        self.assertEqual(response.status_code, 200)
        content = json.loads(response.content)
        self.assertEqual(type(content), list)
        self.assertEqual(type(content[0]["generated_text"]), str)

    def test_malformed_question(self):
        # Invalid UTF-8 body must yield a 400 with an {"error": ...} payload.
        with TestClient(self.app) as client:
            response = client.post("/", data=b"\xc3\x28")

        self.assertEqual(response.status_code, 400)
        content = json.loads(response.content)
        self.assertEqual(set(content.keys()), {"error"})

tests/test_dockers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,13 @@ def test_adapter_transformers(self):
136136
self.framework_docker_test(
137137
"adapter_transformers",
138138
"question-answering",
139-
"calpt/adapter-bert-base-squad1",
139+
"AdapterHub/roberta-base-pf-squad",
140+
)
141+
142+
self.framework_docker_test(
143+
"adapter_transformers",
144+
"summarization",
145+
"AdapterHub/facebook-bart-large_sum_xsum_pfeiffer",
140146
)
141147

142148
self.framework_docker_test(
@@ -145,6 +151,12 @@ def test_adapter_transformers(self):
145151
"AdapterHub/roberta-base-pf-sick",
146152
)
147153

154+
self.framework_docker_test(
155+
"adapter_transformers",
156+
"text-generation",
157+
"AdapterHub/gpt2_lm_poem_pfeiffer",
158+
)
159+
148160
self.framework_docker_test(
149161
"adapter_transformers",
150162
"token-classification",

0 commit comments

Comments
 (0)