merged batch size variables of huggingface runtime (#1058)
saeid93 authored Mar 30, 2023
1 parent 3748b37 commit e017f23
Showing 4 changed files with 25 additions and 122 deletions.
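In short, this commit removes the runtime-specific `batch_size` from `parameters.extra` and derives the Hugging Face pipeline batch size from MLServer's own `max_batch_size` setting instead. As a minimal sketch, a post-change `model-settings.json` looks like the following (values mirror the README example below; the Python write-out is purely illustrative and not part of this commit):

```python
import json

# Sketch of the merged configuration: batching is controlled solely by
# MLServer's max_batch_size / max_batch_time; there is no extra.batch_size.
model_settings = {
    "name": "transformer",
    "implementation": "mlserver_huggingface.HuggingFaceRuntime",
    "parallel_workers": 0,
    "max_batch_size": 128,
    "max_batch_time": 1,
    "parameters": {
        "extra": {
            "task": "text-generation",
            "device": -1,
        }
    },
}

with open("model-settings.json", "w") as f:
    json.dump(model_settings, f, indent=2)
```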
12 changes: 6 additions & 6 deletions docs/examples/huggingface/README.ipynb
@@ -466,11 +466,12 @@
" \"name\": \"transformer\",\n",
" \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n",
" \"parallel_workers\": 0,\n",
" \"max_batch_size\": 128,\n",
" \"max_batch_time\": 1,\n",
" \"parameters\": {\n",
" \"extra\": {\n",
" \"task\": \"text-generation\",\n",
" \"device\": -1,\n",
" \"batch_size\": 128\n",
" \"device\": -1\n",
" }\n",
" }\n",
"}"
@@ -566,8 +567,7 @@
" \"parameters\": {\n",
" \"extra\": {\n",
" \"task\": \"text-generation\",\n",
" \"device\": 0,\n",
" \"batch_size\": 128\n",
" \"device\": 0\n",
" }\n",
" }\n",
"}"
@@ -649,14 +649,14 @@
"{\n",
" \"name\": \"transformer\",\n",
" \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n",
" \"parallel_workers\": 0,\n",
" \"max_batch_size\": 128,\n",
" \"max_batch_time\": 1,\n",
" \"parameters\": {\n",
" \"extra\": {\n",
" \"task\": \"text-generation\",\n",
" \"pretrained_model\": \"distilgpt2\",\n",
" \"device\": 0,\n",
" \"batch_size\": 128\n",
" \"device\": 0\n",
" }\n",
" }\n",
"}"
109 changes: 6 additions & 103 deletions docs/examples/huggingface/README.md
@@ -37,9 +37,6 @@ We will show how to add share a task
}
```

Overwriting ./model-settings.json


Now that we have our config in place, we can start the server by running `mlserver start .`. This needs to either be run from the same directory where our config files are or point to the folder where they are.

```shell
@@ -67,21 +64,6 @@ inference_request = {
requests.post("http://localhost:8080/v2/models/transformer/infer", json=inference_request).json()
```




{'model_name': 'transformer',
'model_version': None,
'id': '9b24304e-730f-4a98-bfde-8949851388a9',
'parameters': None,
'outputs': [{'name': 'output',
'shape': [1],
'datatype': 'BYTES',
'parameters': None,
'data': ['[{"generated_text": "this is a test-case where you\'re checking if someone\'s going to have an encrypted file that they like to open, or whether their file has a hidden contents if their file is not opened. If it\'s the same file, when all the"}]']}]}



### Using Optimum Optimized Models

We can also leverage the Optimum library that allows us to access quantized and optimized models.
@@ -105,9 +87,6 @@ We can download pretrained optimized models from the hub if available by enablin
}
```

Overwriting ./model-settings.json


Once again, you are able to run the model using the MLServer CLI. As before, this needs to either be run from the same directory where our config files are or point to the folder where they are.

```shell
@@ -134,21 +113,6 @@ inference_request = {
requests.post("http://localhost:8080/v2/models/transformer/infer", json=inference_request).json()
```




{'model_name': 'transformer',
'model_version': None,
'id': '296ea44e-7696-4584-af5a-148a7083b2e7',
'parameters': None,
'outputs': [{'name': 'output',
'shape': [1],
'datatype': 'BYTES',
'parameters': None,
'data': ['[{"generated_text": "this is a test that allows us to define the value type, and a function is defined directly with these variables.\\n\\n\\nThe function is defined for a parameter with type\\nIn this example,\\nif you pass a message function like\\ntype"}]']}]}



## Testing Supported Tasks

We can support multiple transformers beyond just text generation; below are examples for a few of the other supported tasks.
@@ -172,9 +136,6 @@ We can support multiple other transformers other than just text generation, belo
}
```

Overwriting ./model-settings.json


Once again, you are able to run the model using the MLServer CLI.

```shell
@@ -203,21 +164,6 @@ inference_request = {
requests.post("http://localhost:8080/v2/models/transformer/infer", json=inference_request).json()
```




{'model_name': 'gpt2-model',
'model_version': None,
'id': '204ad4e7-79ea-40b4-8efb-aed16dedf7ed',
'parameters': None,
'outputs': [{'name': 'output',
'shape': [1],
'datatype': 'BYTES',
'parameters': None,
'data': ['{"score": 0.9869922995567322, "start": 12, "end": 18, "answer": "Seldon"}']}]}



### Sentiment Analysis


@@ -235,9 +181,6 @@ requests.post("http://localhost:8080/v2/models/transformer/infer", json=inferenc
}
```

Overwriting ./model-settings.json


Once again, you are able to run the model using the MLServer CLI.

```shell
@@ -260,21 +203,6 @@ inference_request = {
requests.post("http://localhost:8080/v2/models/transformer/infer", json=inference_request).json()
```




{'model_name': 'transformer',
'model_version': None,
'id': '463ceddb-f426-4815-9c46-9fa9fc5272b1',
'parameters': None,
'outputs': [{'name': 'output',
'shape': [1],
'datatype': 'BYTES',
'parameters': None,
'data': ['[{"label": "NEGATIVE", "score": 0.9996137022972107}]']}]}



## GPU Acceleration

We can also evaluate GPU acceleration and compare the speed on CPU vs GPU using the following parameters
@@ -290,19 +218,17 @@ We first test the time taken with the device=-1 which configures CPU by default
"name": "transformer",
"implementation": "mlserver_huggingface.HuggingFaceRuntime",
"parallel_workers": 0,
"max_batch_size": 128,
"max_batch_time": 1,
"parameters": {
"extra": {
"task": "text-generation",
"device": -1,
"batch_size": 128
"device": -1
}
}
}
```

Overwriting ./model-settings.json


Once again, you are able to run the model using the MLServer CLI.

```shell
@@ -331,9 +257,6 @@ requests.post("http://localhost:8080/v2/models/transformer/infer", json=inferenc
print(f"Elapsed time: {time.monotonic() - start_time}")
```

Elapsed time: 81.57849169999827


We can see that it takes 81 seconds, which is 8 times longer than the GPU example below.

### Testing with GPU
@@ -352,16 +275,12 @@ Now we'll run the benchmark with GPU configured, which we can do by setting `dev
"parameters": {
"extra": {
"task": "text-generation",
"device": 0,
"batch_size": 128
"device": 0
}
}
}
```

Overwriting ./model-settings.json



```python
inference_request = {
@@ -384,9 +303,6 @@
print(f"Elapsed time: {time.monotonic() - start_time}")
```

Elapsed time: 11.27933280000434


We can see that the elapsed time is 8 times less than the CPU version!

### Adaptive Batching with GPU
@@ -403,22 +319,19 @@ We will also configure `max_batch_time` which specifies` the maximum amount of t
{
"name": "transformer",
"implementation": "mlserver_huggingface.HuggingFaceRuntime",
"parallel_workers": 0,
"max_batch_size": 128,
"max_batch_time": 1,
"parameters": {
"extra": {
"task": "text-generation",
"pretrained_model": "distilgpt2",
"device": 0,
"batch_size": 128
"device": 0
}
}
}
```

Overwriting ./model-settings.json


In order to achieve the required throughput of 50 requests per second, we will use `vegeta`, a load-testing tool.

We can now see that the requests are batched and we receive 100% success, even though the requests are sent one-by-one.
@@ -438,16 +351,6 @@ jq -ncM '{"method": "POST", "header": {"Content-Type": ["application/json"] }, "
-type=text
```

Requests [total, rate, throughput] 150, 50.34, 22.28
Duration [total, attack, wait] 6.732s, 2.98s, 3.753s
Latencies [min, mean, 50, 90, 95, 99, max] 1.975s, 3.168s, 3.22s, 4.065s, 4.183s, 4.299s, 4.318s
Bytes In [total, mean] 60978, 406.52
Bytes Out [total, mean] 12300, 82.00
Success [ratio] 100.00%
Status Codes [code:count] 200:150
Error Set:



```python

```
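For reference, a hedged reconstruction of the client-side timing loop used in the CPU vs GPU comparison above. The diff collapses the request body and the loop, so the input name (`args`) and the iteration count are assumptions rather than values taken from this commit:

```python
import time
import requests

# Assumed V2 inference payload; the exact input name is not visible in the diff.
inference_request = {
    "inputs": [
        {
            "name": "args",
            "shape": [1],
            "datatype": "BYTES",
            "data": ["this is a test"],
        }
    ]
}

start_time = time.monotonic()
for _ in range(100):  # iteration count assumed; the original loop is collapsed
    requests.post(
        "http://localhost:8080/v2/models/transformer/infer",
        json=inference_request,
    ).json()
print(f"Elapsed time: {time.monotonic() - start_time}")
```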
13 changes: 8 additions & 5 deletions runtimes/huggingface/mlserver_huggingface/common.py
@@ -6,6 +6,7 @@
import numpy as np
from pydantic import BaseSettings
from mlserver.errors import MLServerError
from mlserver.settings import ModelSettings

from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline
@@ -54,7 +55,6 @@ class Config:
pretrained_tokenizer: Optional[str] = None
optimum_model: bool = False
device: int = -1
batch_size: Optional[int] = None

@property
def task_name(self):
@@ -107,7 +107,9 @@ def parse_parameters_from_env() -> Dict:
return parsed_parameters


def load_pipeline_from_settings(hf_settings: HuggingFaceSettings) -> Pipeline:
def load_pipeline_from_settings(
hf_settings: HuggingFaceSettings, settings: ModelSettings
) -> Pipeline:
"""
TODO
"""
@@ -131,16 +133,17 @@ def load_pipeline_from_settings(hf_settings: HuggingFaceSettings) -> Pipeline:
# https://github.com/huggingface/optimum/issues/191
device = -1

batch_size = 1 if settings.max_batch_size == 0 else settings.max_batch_size
pp = pipeline(
hf_settings.task_name,
model=model,
tokenizer=tokenizer,
device=device,
batch_size=hf_settings.batch_size,
batch_size=batch_size,
)

# If batch_size > 0 we need to ensure tokens are padded
if hf_settings.batch_size:
# If max_batch_size > 0 we need to ensure tokens are padded
if settings.max_batch_size:
pp.tokenizer.pad_token_id = [str(pp.model.config.eos_token_id)] # type: ignore

return pp
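The core of this change is the batch-size derivation in `load_pipeline_from_settings`: with adaptive batching disabled (`max_batch_size == 0`) the pipeline falls back to a batch size of 1, otherwise it follows `max_batch_size`. A standalone sketch of that rule (the helper name is illustrative, not part of the runtime):

```python
def resolve_batch_size(max_batch_size: int) -> int:
    # Mirrors the new logic: adaptive batching disabled (0) means batch size 1;
    # otherwise the pipeline batch size tracks MLServer's max_batch_size.
    return 1 if max_batch_size == 0 else max_batch_size


assert resolve_batch_size(0) == 1      # adaptive batching off
assert resolve_batch_size(128) == 128  # matches the README examples
```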
13 changes: 5 additions & 8 deletions runtimes/huggingface/mlserver_huggingface/runtime.py
@@ -56,12 +56,6 @@ def __init__(self, settings: ModelSettings):
),
)

if settings.max_batch_size != self.hf_settings.batch_size:
logger.warning(
f"hf batch_size: {self.hf_settings.batch_size} is different "
f"from MLServer max_batch_size: {settings.max_batch_size}"
)

super().__init__(settings)

async def load(self) -> bool:
@@ -70,11 +64,14 @@ async def load(self) -> bool:
print(self.hf_settings.task_name)
print("loading model...")
await asyncio.get_running_loop().run_in_executor(
None, load_pipeline_from_settings, self.hf_settings
None,
load_pipeline_from_settings,
self.hf_settings,
self.settings,
)
print("(re)loading model...")
# Now we load the cached model which should not block asyncio
self._model = load_pipeline_from_settings(self.hf_settings)
self._model = load_pipeline_from_settings(self.hf_settings, self.settings)
self._merge_metadata()
print("model has been loaded!")
self.ready = True
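Since `load_pipeline_from_settings` now takes both the Hugging Face settings and the full `ModelSettings`, the runtime passes the two objects when it offloads the blocking load. A minimal sketch of that call pattern (the wrapper function and its argument names are illustrative; the settings objects are assumed to be constructed elsewhere):

```python
import asyncio

from mlserver_huggingface.common import load_pipeline_from_settings


async def load_pipeline(hf_settings, model_settings):
    # Run the blocking pipeline construction off the event loop, passing both
    # settings objects, as the updated runtime does.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, load_pipeline_from_settings, hf_settings, model_settings
    )
```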
