Add support for Qwen2 models (#746)
* feat(export): allow setting fuse_qkv=False for decoders

* feat(decoder): add Qwen2 modeling code

* feat(decoder): allow export from a local class

* refactor(exporters): isolate decoder export configs

These configs depend on the availability of transformers_neuronx, so
it makes sense to only register them if the package is available.

* feat(export): add support for Qwen2

* test(decoder): add Qwen2

* test(tgi): add Qwen2 tests

* perf: add QwenMath accuracy example results

* perf(tgi): add Qwen2.5-7b performance results

* review: add type annotations

* review: add comments in Qwen2ForSampling
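
For reviewers who want to try this end to end, here is a minimal sketch of exporting and running a Qwen2 model with the new support (model ID and compilation arguments are illustrative, not taken from this PR):

```python
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM

model_id = "Qwen/Qwen2.5-7B-Instruct"

# export=True compiles the checkpoint for Neuron; the Qwen2 export config
# added in this PR transparently selects fuse_qkv=False.
model = NeuronModelForCausalLM.from_pretrained(
    model_id,
    export=True,
    batch_size=1,
    sequence_length=4096,
    num_cores=8,
    auto_cast_type="bf16",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("What is the capital of France?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```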
dacorvo authored Dec 5, 2024
1 parent 93aabc6 commit 6ba3db9
Showing 22 changed files with 706 additions and 60 deletions.
@@ -0,0 +1,5 @@
MODEL_ID='Qwen/Qwen2.5-7B-Instruct'
HF_AUTO_CAST_TYPE='bf16'
MAX_BATCH_SIZE=32
MAX_INPUT_TOKENS=4000
MAX_TOTAL_TOKENS=4096
@@ -0,0 +1,76 @@
version: '3.7'

services:
tgi-1:
image: neuronx-tgi:latest
ports:
- "8081:8081"
environment:
- PORT=8081
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron0"
- "/dev/neuron1"
- "/dev/neuron2"
- "/dev/neuron3"

tgi-2:
image: neuronx-tgi:latest
ports:
- "8082:8082"
environment:
- PORT=8082
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron4"
- "/dev/neuron5"
- "/dev/neuron6"
- "/dev/neuron7"

tgi-3:
image: neuronx-tgi:latest
ports:
- "8083:8083"
environment:
- PORT=8083
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron8"
- "/dev/neuron9"
- "/dev/neuron10"
- "/dev/neuron11"

loadbalancer:
image: nginx:alpine
ports:
- "8080:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
depends_on:
- tgi-1
- tgi-2
- tgi-3
deploy:
placement:
constraints: [node.role == manager]
@@ -0,0 +1,15 @@
### Nginx TGI Load Balancer
events {}
http {
upstream tgicluster {
server tgi-1:8081;
server tgi-2:8082;
server tgi-3:8083;
}
server {
listen 80;
location / {
proxy_pass http://tgicluster;
}
}
}
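
The three TGI replicas above each get four Neuron devices, and nginx round-robins incoming traffic across them on port 8080. Once the stack is up (for example with `docker compose --env-file .env up`), requests can be sent to the load balancer; a minimal client sketch against TGI's `/generate` endpoint (prompt and parameters are illustrative):

```python
import requests

# Requests hit the nginx load balancer, which distributes them
# across the three TGI replicas.
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=120,
)
response.raise_for_status()
print(response.json()["generated_text"])
```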
@@ -0,0 +1,12 @@
model_id,Date,Input type,Requests per Second,Request Latency (s),Time-to-first-token (ms),Inter Token Latency (ms),Output Token Throughput (t/s)
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,synchronous,0.16124266966166817,6.200973322516994,309.0427423778333,25.97485797662497,36.57662664430473
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@0.55 req/sec,0.49461558754572243,11.853755130606183,361.0207387956522,48.287324351631526,117.7268931509251
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@0.93 req/sec,0.8060968082815412,16.24768308375744,375.21653479718145,67.57783339749032,189.26981548491378
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@1.32 req/sec,1.083945791108799,21.60137509382688,391.8051444567167,90.79909233959562,253.1763846248275
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@1.71 req/sec,1.360321529639815,22.870551178060428,896.7999958553197,94.77224932706102,315.15228174103277
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.10 req/sec,1.6004688460192356,27.518067228297394,1464.1120346883934,112.11173711716121,371.45881623077696
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.48 req/sec,1.8374073942778475,29.824548766450974,1626.0160196174695,122.31885821081055,423.5491627411547
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.87 req/sec,2.0547734036381797,33.20240214091389,2375.624083671249,133.96148232126046,472.60651633847726
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@3.64 req/sec,2.0780593811446972,40.66464872033365,8195.832600516658,138.66332340499426,486.5759282406912
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@3.26 req/sec,2.116392255309062,36.28229375148383,4732.812661824264,134.96114258046998,494.68585904605914
Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,throughput,3.6428876172319473,25.543468462793452,7593.348495583786,77.56031301334828,844.191272561698
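
The rows above show aggregate throughput climbing with the request rate while time-to-first-token degrades. A quick sketch for eyeballing that trade-off (assumes the CSV is saved locally as `tgi-results.csv`; the filename is illustrative):

```python
import csv

with open("tgi-results.csv") as f:
    rows = list(csv.DictReader(f))

# One line per load level: higher request rates buy more throughput
# at the cost of time-to-first-token.
for row in rows:
    print(
        f"{row['Input type']:>26} | "
        f"TTFT {float(row['Time-to-first-token (ms)']):8.1f} ms | "
        f"{float(row['Output Token Throughput (t/s)']):7.1f} tok/s"
    )
```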
7 changes: 7 additions & 0 deletions benchmark/text-generation/accuracy/README.md
@@ -21,3 +21,10 @@ You can evaluate:
| | |none | 0|acc_norm ||0.7581|± |0.0043|
|lambada_openai| 1|none | 0|acc ||0.7173|± |0.0063|
| | |none | 0|perplexity ||3.1102|± |0.0769|

### Qwen/Qwen2.5-Math-7B-Instruct

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match||0.8878|± |0.0087|
| | |strict-match | 5|exact_match||0.8870|± |0.0087|
20 changes: 15 additions & 5 deletions optimum/exporters/neuron/base.py
@@ -445,17 +445,23 @@ class NeuronDecoderConfig(NeuronConfig):
NEURONX_CLASS = None
CONTINUOUS_BATCHING = False
ATTENTION_lAYOUT = "HSB"
FUSE_QKV = True

def __init__(self, task: str):
if not is_transformers_neuronx_available():
raise ModuleNotFoundError(
"The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]."
)
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(f"{class_name} not found in {module_name}. Please check transformers-neuronx version.")
if isinstance(self.NEURONX_CLASS, type):
self._neuronx_class = self.NEURONX_CLASS
else:
module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1)
module = importlib.import_module(f"transformers_neuronx.{module_name}")
self._neuronx_class = getattr(module, class_name, None)
if self._neuronx_class is None:
raise ImportError(
f"{class_name} not found in {module_name}. Please check transformers-neuronx version."
)

@property
def neuronx_class(self):
@@ -468,3 +474,7 @@ def continuous_batching(self):
@property
def attention_layout(self):
return self.ATTENTION_lAYOUT

@property
def fuse_qkv(self):
return self.FUSE_QKV
22 changes: 22 additions & 0 deletions optimum/exporters/neuron/model_configs/__init__.py
@@ -0,0 +1,22 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specific Neuron configurations."""

from ....neuron.utils.import_utils import is_transformers_neuronx_available
from .traced_configs import *


if is_transformers_neuronx_available():
from .decoder_configs import *
65 changes: 65 additions & 0 deletions optimum/exporters/neuron/model_configs/decoder_configs.py
@@ -0,0 +1,65 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neuron export configurations for models using transformers_neuronx."""


from optimum.exporters.tasks import TasksManager

from ....neuron.models.qwen2.model import Qwen2ForSampling
from ..config import TextNeuronDecoderConfig


register_in_tasks_manager = TasksManager.create_register("neuron")


@register_in_tasks_manager("gpt2", "text-generation")
class GPT2NeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "gpt2.model.GPT2ForSampling"


@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "llama.model.LlamaForSampling"
CONTINUOUS_BATCHING = True
ATTENTION_lAYOUT = "BSH"


@register_in_tasks_manager("opt", "text-generation")
class OPTNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "opt.model.OPTForSampling"


@register_in_tasks_manager("bloom", "text-generation")
class BloomNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "bloom.model.BloomForSampling"


@register_in_tasks_manager("mistral", "text-generation")
class MistralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mistral.model.MistralForSampling"
CONTINUOUS_BATCHING = True


@register_in_tasks_manager("mixtral", "text-generation")
class MixtralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mixtral.model.MixtralForSampling"
CONTINUOUS_BATCHING = False


@register_in_tasks_manager("qwen2", "text-generation")
class Qwen2NeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = Qwen2ForSampling
CONTINUOUS_BATCHING = True
FUSE_QKV = False
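
Note that Qwen2 is the first config to point `NEURONX_CLASS` at a class object rather than a dotted path, since its modeling code lives in `optimum.neuron` instead of `transformers_neuronx`; it also disables fused QKV, presumably because Qwen2 attention projections carry biases. Adding another transformers_neuronx decoder follows the same pattern; a sketch with a hypothetical `phi` architecture (class path and flags are illustrative, not an actual supported model):

```python
@register_in_tasks_manager("phi", "text-generation")
class PhiNeuronConfig(TextNeuronDecoderConfig):
    # A dotted path is resolved inside transformers_neuronx at config
    # __init__ time; pass a class directly (like Qwen2ForSampling above)
    # when the modeling code lives elsewhere.
    NEURONX_CLASS = "phi.model.PhiForSampling"
    CONTINUOUS_BATCHING = True
```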
@@ -20,15 +20,8 @@

import torch

from ...neuron.distributed import ParallelizersManager
from ...neuron.utils import (
ASTDummyAudioInputGenerator,
DummyBeamValuesGenerator,
DummyControNetInputGenerator,
DummyMaskedPosGenerator,
is_neuronx_distributed_available,
)
from ...utils import (
from optimum.exporters.tasks import TasksManager
from optimum.utils import (
DummyInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyTextInputGenerator,
@@ -42,16 +35,23 @@
NormalizedVisionConfig,
is_diffusers_available,
)
from ..tasks import TasksManager
from .config import (

from ....neuron.distributed import ParallelizersManager
from ....neuron.utils import (
ASTDummyAudioInputGenerator,
DummyBeamValuesGenerator,
DummyControNetInputGenerator,
DummyMaskedPosGenerator,
is_neuronx_distributed_available,
)
from ..config import (
AudioNeuronConfig,
TextAndVisionNeuronConfig,
TextEncoderNeuronConfig,
TextNeuronDecoderConfig,
TextSeq2SeqNeuronConfig,
VisionNeuronConfig,
)
from .model_wrappers import (
from ..model_wrappers import (
ControlNetNeuronWrapper,
NoCacheModelWrapper,
SentenceTransformersCLIPNeuronWrapper,
@@ -768,18 +768,6 @@ def patch_model_for_export(
return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)


@register_in_tasks_manager("gpt2", "text-generation")
class GPT2NeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "gpt2.model.GPT2ForSampling"


@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "llama.model.LlamaForSampling"
CONTINUOUS_BATCHING = True
ATTENTION_lAYOUT = "BSH"


@register_in_tasks_manager("t5-encoder", "text2text-generation")
class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
@@ -980,25 +968,3 @@ def generate_io_aliases(self, decoder):
aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace

return aliases


@register_in_tasks_manager("opt", "text-generation")
class OPTNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "opt.model.OPTForSampling"


@register_in_tasks_manager("bloom", "text-generation")
class BloomNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "bloom.model.BloomForSampling"


@register_in_tasks_manager("mistral", "text-generation")
class MistralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mistral.model.MistralForSampling"
CONTINUOUS_BATCHING = True


@register_in_tasks_manager("mixtral", "text-generation")
class MixtralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mixtral.model.MixtralForSampling"
CONTINUOUS_BATCHING = False
6 changes: 4 additions & 2 deletions optimum/neuron/modeling_decoder.py
@@ -181,12 +181,14 @@ def __init__(
tnx_kwargs["neuron_config"] = NeuronConfig(
continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size),
attention_layout=exporter.attention_layout,
fuse_qkv=True,
fuse_qkv=exporter.fuse_qkv,
)
tnx_kwargs["n_positions"] = [sequence_length]
tnx_kwargs["context_length_estimate"] = [sequence_length]
else:
tnx_kwargs["neuron_config"] = NeuronConfig(attention_layout=exporter.attention_layout, fuse_qkv=True)
tnx_kwargs["neuron_config"] = NeuronConfig(
attention_layout=exporter.attention_layout, fuse_qkv=exporter.fuse_qkv
)
tnx_kwargs["n_positions"] = sequence_length

# Instantiate neuronx model
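
For context, `fuse_qkv` tells transformers_neuronx whether to merge the query, key and value projections into a single matmul. The change above simply forwards the per-architecture default from the export config instead of hard-coding `True`, which is what lets the Qwen2 config opt out.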
14 changes: 14 additions & 0 deletions optimum/neuron/models/__init__.py
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 changes: 14 additions & 0 deletions optimum/neuron/models/qwen2/__init__.py
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.