diff --git a/benchmark/text-generation-inference/performance/qwen2.5-7b/.env b/benchmark/text-generation-inference/performance/qwen2.5-7b/.env new file mode 100644 index 000000000..925fe50eb --- /dev/null +++ b/benchmark/text-generation-inference/performance/qwen2.5-7b/.env @@ -0,0 +1,5 @@ +MODEL_ID='Qwen/Qwen2.5-7B-Instruct' +HF_AUTO_CAST_TYPE='bf16' +MAX_BATCH_SIZE=32 +MAX_INPUT_TOKENS=4000 +MAX_TOTAL_TOKENS=4096 diff --git a/benchmark/text-generation-inference/performance/qwen2.5-7b/docker-compose.yaml b/benchmark/text-generation-inference/performance/qwen2.5-7b/docker-compose.yaml new file mode 100644 index 000000000..960e93ea8 --- /dev/null +++ b/benchmark/text-generation-inference/performance/qwen2.5-7b/docker-compose.yaml @@ -0,0 +1,76 @@ +version: '3.7' + +services: + tgi-1: + image: neuronx-tgi:latest + ports: + - "8081:8081" + environment: + - PORT=8081 + - MODEL_ID=${MODEL_ID} + - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE} + - HF_NUM_CORES=8 + - MAX_BATCH_SIZE=${MAX_BATCH_SIZE} + - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS} + - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS} + - MAX_CONCURRENT_REQUESTS=512 + - HF_TOKEN=${HF_TOKEN} + devices: + - "/dev/neuron0" + - "/dev/neuron1" + - "/dev/neuron2" + - "/dev/neuron3" + + tgi-2: + image: neuronx-tgi:latest + ports: + - "8082:8082" + environment: + - PORT=8082 + - MODEL_ID=${MODEL_ID} + - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE} + - HF_NUM_CORES=8 + - MAX_BATCH_SIZE=${MAX_BATCH_SIZE} + - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS} + - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS} + - MAX_CONCURRENT_REQUESTS=512 + - HF_TOKEN=${HF_TOKEN} + devices: + - "/dev/neuron4" + - "/dev/neuron5" + - "/dev/neuron6" + - "/dev/neuron7" + + tgi-3: + image: neuronx-tgi:latest + ports: + - "8083:8083" + environment: + - PORT=8083 + - MODEL_ID=${MODEL_ID} + - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE} + - HF_NUM_CORES=8 + - MAX_BATCH_SIZE=${MAX_BATCH_SIZE} + - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS} + - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS} + - MAX_CONCURRENT_REQUESTS=512 + - HF_TOKEN=${HF_TOKEN} + devices: + - "/dev/neuron8" + - "/dev/neuron9" + - "/dev/neuron10" + - "/dev/neuron11" + + loadbalancer: + image: nginx:alpine + ports: + - "8080:80" + volumes: + - ./nginx.conf:/etc/nginx/nginx.conf:ro + depends_on: + - tgi-1 + - tgi-2 + - tgi-3 + deploy: + placement: + constraints: [node.role == manager] diff --git a/benchmark/text-generation-inference/performance/qwen2.5-7b/nginx.conf b/benchmark/text-generation-inference/performance/qwen2.5-7b/nginx.conf new file mode 100644 index 000000000..37a3b8721 --- /dev/null +++ b/benchmark/text-generation-inference/performance/qwen2.5-7b/nginx.conf @@ -0,0 +1,15 @@ +### Nginx TGI Load Balancer +events {} +http { + upstream tgicluster { + server tgi-1:8081; + server tgi-2:8082; + server tgi-3:8083; + } + server { + listen 80; + location / { + proxy_pass http://tgicluster; + } + } +} diff --git a/benchmark/text-generation-inference/performance/qwen2.5-7b/tgi-results.csv b/benchmark/text-generation-inference/performance/qwen2.5-7b/tgi-results.csv new file mode 100644 index 000000000..17cf6125f --- /dev/null +++ b/benchmark/text-generation-inference/performance/qwen2.5-7b/tgi-results.csv @@ -0,0 +1,12 @@ +model_id,Date,Input type,Requests per Second,Request Latency (s),Time-to-first-token (ms),Inter Token Latency (ms),Output Token Throughput (t/s) +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,synchronous,0.16124266966166817,6.200973322516994,309.0427423778333,25.97485797662497,36.57662664430473 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@0.55 
req/sec,0.49461558754572243,11.853755130606183,361.0207387956522,48.287324351631526,117.7268931509251 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@0.93 req/sec,0.8060968082815412,16.24768308375744,375.21653479718145,67.57783339749032,189.26981548491378 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@1.32 req/sec,1.083945791108799,21.60137509382688,391.8051444567167,90.79909233959562,253.1763846248275 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@1.71 req/sec,1.360321529639815,22.870551178060428,896.7999958553197,94.77224932706102,315.15228174103277 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.10 req/sec,1.6004688460192356,27.518067228297394,1464.1120346883934,112.11173711716121,371.45881623077696 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.48 req/sec,1.8374073942778475,29.824548766450974,1626.0160196174695,122.31885821081055,423.5491627411547 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@2.87 req/sec,2.0547734036381797,33.20240214091389,2375.624083671249,133.96148232126046,472.60651633847726 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@3.64 req/sec,2.0780593811446972,40.66464872033365,8195.832600516658,138.66332340499426,486.5759282406912 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,asynchronous@3.26 req/sec,2.116392255309062,36.28229375148383,4732.812661824264,134.96114258046998,494.68585904605914 +Qwen_Qwen2.5-7B-Instruct,2024-12-03-14-55-31,throughput,3.6428876172319473,25.543468462793452,7593.348495583786,77.56031301334828,844.191272561698 diff --git a/benchmark/text-generation/accuracy/README.md b/benchmark/text-generation/accuracy/README.md index 0c3a166f9..cef3e3fd2 100644 --- a/benchmark/text-generation/accuracy/README.md +++ b/benchmark/text-generation/accuracy/README.md @@ -21,3 +21,10 @@ You can evaluate: | | |none | 0|acc_norm |↑ |0.7581|± |0.0043| |lambada_openai| 1|none | 0|acc |↑ |0.7173|± |0.0063| | | |none | 0|perplexity |↓ |3.1102|± |0.0769| + +### Qwen/Qwen2.5-Math-7B-Instruct + +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8878|± |0.0087| +| | |strict-match | 5|exact_match|↑ |0.8870|± |0.0087| diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py index 9520b2483..948b8e830 100644 --- a/optimum/exporters/neuron/base.py +++ b/optimum/exporters/neuron/base.py @@ -445,17 +445,23 @@ class NeuronDecoderConfig(NeuronConfig): NEURONX_CLASS = None CONTINUOUS_BATCHING = False ATTENTION_lAYOUT = "HSB" + FUSE_QKV = True def __init__(self, task: str): if not is_transformers_neuronx_available(): raise ModuleNotFoundError( "The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]." ) - module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1) - module = importlib.import_module(f"transformers_neuronx.{module_name}") - self._neuronx_class = getattr(module, class_name, None) - if self._neuronx_class is None: - raise ImportError(f"{class_name} not found in {module_name}. 
Please check transformers-neuronx version.") + if isinstance(self.NEURONX_CLASS, type): + self._neuronx_class = self.NEURONX_CLASS + else: + module_name, class_name = self.NEURONX_CLASS.rsplit(".", maxsplit=1) + module = importlib.import_module(f"transformers_neuronx.{module_name}") + self._neuronx_class = getattr(module, class_name, None) + if self._neuronx_class is None: + raise ImportError( + f"{class_name} not found in {module_name}. Please check transformers-neuronx version." + ) @property def neuronx_class(self): @@ -468,3 +474,7 @@ def continuous_batching(self): @property def attention_layout(self): return self.ATTENTION_lAYOUT + + @property + def fuse_qkv(self): + return self.FUSE_QKV diff --git a/optimum/exporters/neuron/model_configs/__init__.py b/optimum/exporters/neuron/model_configs/__init__.py new file mode 100644 index 000000000..875932f43 --- /dev/null +++ b/optimum/exporters/neuron/model_configs/__init__.py @@ -0,0 +1,22 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model specific Neuron configurations.""" + +from ....neuron.utils.import_utils import is_transformers_neuronx_available +from .traced_configs import * + + +if is_transformers_neuronx_available(): + from .decoder_configs import * diff --git a/optimum/exporters/neuron/model_configs/decoder_configs.py b/optimum/exporters/neuron/model_configs/decoder_configs.py new file mode 100644 index 000000000..dd7f01d3b --- /dev/null +++ b/optimum/exporters/neuron/model_configs/decoder_configs.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Neuron export configurations for models using transformers_neuronx.""" + + +from optimum.exporters.tasks import TasksManager + +from ....neuron.models.qwen2.model import Qwen2ForSampling +from ..config import TextNeuronDecoderConfig + + +register_in_tasks_manager = TasksManager.create_register("neuron") + + +@register_in_tasks_manager("gpt2", "text-generation") +class GPT2NeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "gpt2.model.GPT2ForSampling" + + +@register_in_tasks_manager("llama", "text-generation") +class LLamaNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "llama.model.LlamaForSampling" + CONTINUOUS_BATCHING = True + ATTENTION_lAYOUT = "BSH" + + +@register_in_tasks_manager("opt", "text-generation") +class OPTNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "opt.model.OPTForSampling" + + +@register_in_tasks_manager("bloom", "text-generation") +class BloomNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "bloom.model.BloomForSampling" + + +@register_in_tasks_manager("mistral", "text-generation") +class MistralNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "mistral.model.MistralForSampling" + CONTINUOUS_BATCHING = True + + +@register_in_tasks_manager("mixtral", "text-generation") +class MixtralNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "mixtral.model.MixtralForSampling" + CONTINUOUS_BATCHING = False + + +@register_in_tasks_manager("qwen2", "text-generation") +class Qwen2NeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = Qwen2ForSampling + CONTINUOUS_BATCHING = True + FUSE_QKV = False diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs/traced_configs.py similarity index 96% rename from optimum/exporters/neuron/model_configs.py rename to optimum/exporters/neuron/model_configs/traced_configs.py index 7a8ffee4e..c936f00f8 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs/traced_configs.py @@ -20,15 +20,8 @@ import torch -from ...neuron.distributed import ParallelizersManager -from ...neuron.utils import ( - ASTDummyAudioInputGenerator, - DummyBeamValuesGenerator, - DummyControNetInputGenerator, - DummyMaskedPosGenerator, - is_neuronx_distributed_available, -) -from ...utils import ( +from optimum.exporters.tasks import TasksManager +from optimum.utils import ( DummyInputGenerator, DummySeq2SeqDecoderTextInputGenerator, DummyTextInputGenerator, @@ -42,16 +35,23 @@ NormalizedVisionConfig, is_diffusers_available, ) -from ..tasks import TasksManager -from .config import ( + +from ....neuron.distributed import ParallelizersManager +from ....neuron.utils import ( + ASTDummyAudioInputGenerator, + DummyBeamValuesGenerator, + DummyControNetInputGenerator, + DummyMaskedPosGenerator, + is_neuronx_distributed_available, +) +from ..config import ( AudioNeuronConfig, TextAndVisionNeuronConfig, TextEncoderNeuronConfig, - TextNeuronDecoderConfig, TextSeq2SeqNeuronConfig, VisionNeuronConfig, ) -from .model_wrappers import ( +from ..model_wrappers import ( ControlNetNeuronWrapper, NoCacheModelWrapper, SentenceTransformersCLIPNeuronWrapper, @@ -768,18 +768,6 @@ def patch_model_for_export( return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True) -@register_in_tasks_manager("gpt2", "text-generation") -class GPT2NeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "gpt2.model.GPT2ForSampling" - - -@register_in_tasks_manager("llama", "text-generation") -class LLamaNeuronConfig(TextNeuronDecoderConfig): 
- NEURONX_CLASS = "llama.model.LlamaForSampling" - CONTINUOUS_BATCHING = True - ATTENTION_lAYOUT = "BSH" - - @register_in_tasks_manager("t5-encoder", "text2text-generation") class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 @@ -980,25 +968,3 @@ def generate_io_aliases(self, decoder): aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace return aliases - - -@register_in_tasks_manager("opt", "text-generation") -class OPTNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "opt.model.OPTForSampling" - - -@register_in_tasks_manager("bloom", "text-generation") -class BloomNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "bloom.model.BloomForSampling" - - -@register_in_tasks_manager("mistral", "text-generation") -class MistralNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "mistral.model.MistralForSampling" - CONTINUOUS_BATCHING = True - - -@register_in_tasks_manager("mixtral", "text-generation") -class MixtralNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "mixtral.model.MixtralForSampling" - CONTINUOUS_BATCHING = False diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index 246de6d39..e87c8bb75 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -181,12 +181,14 @@ def __init__( tnx_kwargs["neuron_config"] = NeuronConfig( continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size), attention_layout=exporter.attention_layout, - fuse_qkv=True, + fuse_qkv=exporter.fuse_qkv, ) tnx_kwargs["n_positions"] = [sequence_length] tnx_kwargs["context_length_estimate"] = [sequence_length] else: - tnx_kwargs["neuron_config"] = NeuronConfig(attention_layout=exporter.attention_layout, fuse_qkv=True) + tnx_kwargs["neuron_config"] = NeuronConfig( + attention_layout=exporter.attention_layout, fuse_qkv=exporter.fuse_qkv + ) tnx_kwargs["n_positions"] = sequence_length # Instantiate neuronx model diff --git a/optimum/neuron/models/__init__.py b/optimum/neuron/models/__init__.py new file mode 100644 index 000000000..fdc025786 --- /dev/null +++ b/optimum/neuron/models/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/optimum/neuron/models/qwen2/__init__.py b/optimum/neuron/models/qwen2/__init__.py new file mode 100644 index 000000000..fdc025786 --- /dev/null +++ b/optimum/neuron/models/qwen2/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/optimum/neuron/models/qwen2/config.py b/optimum/neuron/models/qwen2/config.py
new file mode 100644
index 000000000..96d33c04a
--- /dev/null
+++ b/optimum/neuron/models/qwen2/config.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import PretrainedConfig
+from transformers_neuronx.llama.config import LlamaConfig
+
+
+class Qwen2Config(LlamaConfig):
+    """The Qwen2 model uses the same configuration as the TnX Llama model."""
+
+    def __init__(
+        self, config: PretrainedConfig, n_positions: int, batch_size: int, amp: str, tp_degree: int, **kwargs
+    ):
+        super().__init__(config, n_positions, batch_size, amp, tp_degree, **kwargs)
+        self.model_type = "qwen2"
diff --git a/optimum/neuron/models/qwen2/model.py b/optimum/neuron/models/qwen2/model.py
new file mode 100644
index 000000000..8ee60d9b4
--- /dev/null
+++ b/optimum/neuron/models/qwen2/model.py
@@ -0,0 +1,298 @@
+# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import warnings
+
+import torch
+from transformers import PretrainedConfig
+from transformers_neuronx import base, bucket, decoder, ops, utils
+from transformers_neuronx.config import NeuronConfig
+from transformers_neuronx.constants import KV_SHARD_PAD, LAYOUT_HSB
+from transformers_neuronx.llama.hlo import LlamaForSamplingNoEmbeddingHlo
+
+from .config import Qwen2Config
+from .modules import Qwen2ForCausalLM
+
+
+class Qwen2ForSampling(base.NeuronModelBase):
+    """The Qwen2 model is essentially a Llama model with biases in the linear projections.
+
+    The implementation in this class is very similar to the one used for Llama in TnX.
+    The only differences are:
+    - the config (Qwen2Config) and base model (Qwen2ForCausalLM) used in __init__,
+    - the addition of bias parameters when loading weights from the checkpoint model.
+ """ + + def __init__( + self, + config: PretrainedConfig, + *, + n_positions: int = 2048, + batch_size: int = 1, + amp: str = "f32", + tp_degree: int = 2, + context_length_estimate: int = None, + context_unroll: int = None, + unroll: int = None, + neuron_config: NeuronConfig = None, + prefixed_length: int = 0, + **kwargs, + ): + config = Qwen2Config(config, n_positions, batch_size, amp, tp_degree) + super().__init__(Qwen2ForCausalLM, config) + self.context_pre_hook = None + self.context_hook = None + self.config = config + self.neuron_config = neuron_config if neuron_config else NeuronConfig() + if self.neuron_config.shard_over_sequence: + n_kv_head = self.config.num_key_value_heads + kv_shard_degree = self.config.tp_degree // n_kv_head + assert kv_shard_degree <= KV_SHARD_PAD, "increase kv_shard degree is higher than default 128" + warnings.warn(f"shard over sequence enabled, increasing n_positions {n_positions} by 128") + if isinstance(n_positions, list): + npos = sorted(n_positions) + npos[-1] += KV_SHARD_PAD + else: + npos = n_positions + KV_SHARD_PAD + self.config.n_positions = npos + config.n_positions = npos + n_positions = npos + if self.neuron_config.on_device_generation: + self.neuron_config.on_device_generation.vocab_size = self.config.vocab_size + + self.layers_after_partition = self.neuron_config.auto_layer_partition(config.num_hidden_layers) + self.prefixed_length = prefixed_length + + if context_unroll is None: + context_unroll = len(self.layers_after_partition) + self.context_unroll = context_unroll + + if unroll is None: + unroll = len(self.layers_after_partition) + self.unroll = unroll + + self.token_buckets = bucket.token_sizes(n_positions) + self.context_buckets = bucket.context_sizes(context_length_estimate, self.token_buckets) + # input length should be divisable by tp_degree to activate seq paralle + if neuron_config and neuron_config.sequence_parallel_norm: + for bucket_size in self.context_buckets: + if ( + bucket_size > neuron_config.sequence_parallel_norm_threshold + and bucket_size % self.config.tp_degree != 0 + ): + raise ValueError( + f"Sequence parallel normalization requires the bucket size ({bucket_size}) to be divisible by the tensor parallel degree ({self.config.tp_degree})" + ) + self.window_context_buckets = [] + if prefixed_length: + if prefixed_length not in self.context_buckets: + self.context_buckets.append(prefixed_length) + self.context_buckets = sorted(self.context_buckets) + + self.batch_sizes = bucket.batch_sizes(batch_size) + self.context_batch_sizes = ( + [1] if self.neuron_config and self.neuron_config.continuous_batching else self.batch_sizes + ) + hlo_builder = LlamaForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config) + self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding( + tp_degree=tp_degree, + n_positions_list=self.token_buckets, + n_active_tokens=1, + batch_size=self.batch_sizes, + attention_head_size=config.attention_head_size, + amp=amp, + num_layers=len(self.layers_after_partition), + n_head=config.num_attention_heads, + n_kv_head=config.num_key_value_heads, + unroll=unroll, + neuron_config=self.neuron_config, + allow_pad=True, + builder=hlo_builder, + ) + self.decoder_lm_head = self.decoder_param_set.init_token_decoder( + unroll=self.unroll, buckets=self.token_buckets, model_obj=self + ) + self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder( + unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self + ) + self.decoder_lm_head_for_speculation = {} + 
+        self.decoder_lm_head_for_window_context = {}
+
+    def load_weights(self):
+        self.materialize_embeddings()
+        ops.init()
+
+        for layer_id, layer in enumerate(self.chkpt_model.model.layers):
+            if layer_id not in self.layers_after_partition:
+                continue
+            layer.materialize()
+            attn = layer.self_attn
+            mlp = layer.mlp
+            if self.neuron_config and self.neuron_config.quant:
+                is_unit_scale = self.neuron_config.quant.is_unit_scale(layer_id)
+            else:
+                is_unit_scale = False
+            new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale)
+            new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
+            new_layer.add_attention_query(attn.q_proj.weight.detach().T, attn.q_proj.bias.detach())
+            new_layer.add_attention_key(attn.k_proj.weight.detach().T, attn.k_proj.bias.detach())
+            new_layer.add_attention_value(attn.v_proj.weight.detach().T, attn.v_proj.bias.detach())
+            if self.neuron_config and self.neuron_config.attn_output_transposed:
+                new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True)
+            else:
+                new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False)
+            new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None)
+
+            # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediate state
+            if self.neuron_config.fuse_mlp:
+                assert all(
+                    getattr(mlp, attr, None) for attr in ["gate_proj", "up_proj"]
+                ), "fuse_mlp requires both gate_proj and up_proj weights"
+                assert all(
+                    getattr(mlp, attr, None).weight.shape[0] % self.config.tp_degree == 0
+                    for attr in ["gate_proj", "up_proj"]
+                ), f"mlp weights are not divisible by tp_degree {self.config.tp_degree}"
+                mlp_in_weight = utils.interleave_mlp(
+                    mlp.gate_proj.weight, mlp.up_proj.weight, tp_degree=self.config.tp_degree, dim=0
+                )
+                new_layer.add_mlp_input(mlp_in_weight.T.detach(), None)
+                if self.neuron_config.mlp_out_weight_transpose:
+                    new_layer.add_mlp_output(
+                        mlp.down_proj.weight.T.detach(),
+                        None,
+                        sharding=0,
+                        transposed=True,
+                    )
+                else:
+                    new_layer.add_mlp_output(
+                        mlp.down_proj.weight.detach(),
+                        None,
+                        sharding=1,
+                        transposed=False,
+                    )
+            else:
+                new_layer.add_parameter(
+                    mlp.gate_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
+                )
+                new_layer.add_parameter(
+                    mlp.up_proj.weight.T, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True
+                )
+                if self.neuron_config.weight_tiling:
+                    new_layer.add_parameter(
+                        mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True, allow_transform=True
+                    )
+                else:
+                    if self.neuron_config.mlp_out_weight_transpose:
+                        new_layer.add_parameter(
+                            mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True
+                        )
+                    else:
+                        new_layer.add_parameter(
+                            mlp.down_proj.weight, sharding=1, allow_pad=True, allow_quantize=True, out_feature_dim=0
+                        )
+            new_layer.to_neuron()
+            layer.nullify()
+        if self.neuron_config.shard_over_sequence:
+            self.decoder_lm_head.add_pre_layer_parameter(torch.arange(self.config.tp_degree), sharding=0)
+        # For pipeline parallel, we need to load ln and lm_head for now even if the pipeline stage doesn't compute them, because
+        # 1) we need the ln_lm_head hlo for pp0 to get the logits shape and dtype
+        # 2) we don't need these for intermediate pp stages, but to keep things simple, just include ln_lm_head for all pp stages for now
+        # 3) to get ln_lm_head hlo, we need to do weight loading and sharding
+        # 4) this will introduce extra memory allocation, but ln_lm_head i/o tensor is much smaller and we can get rid of it when we can construct hlo in init
+        ln_f = self.chkpt_model.model.norm
+        ln_f.materialize()
+        self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None)
+
+        lm_head = self.chkpt_model.lm_head
+        lm_head.materialize()
+        self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T)
+        if self.neuron_config.on_device_embedding:
+            if self.neuron_config.sequence_parallel_norm:
+                self.decoder_lm_head.add_pre_layer_parameter(
+                    self.chkpt_model.model.embed_tokens.weight, sharding=None, allow_pad=True
+                )
+            else:
+                self.decoder_lm_head.add_pre_layer_parameter(
+                    self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True
+                )
+        lm_head.nullify()
+
+        self.decoder_lm_head.to_neuron()
+        self.init_rest_of_model()
+
+    def materialize_embeddings(self):
+        # Materialize the embedding to CPU
+        self.chkpt_model.model.embed_tokens.materialize()
+
+    def init_rest_of_model(self):
+        # Pipeline parallel doesn't support executor right now
+        if not self.neuron_config.is_pp():
+            self.decoder_lm_head.use_executor = True
+
+        if self.context_buckets:
+            for context_length_estimate in self.context_buckets:
+                for batch_size in self.context_batch_sizes:
+                    model = self.decoder_lm_head.build_weight_shared(
+                        share_caches=True, new=self.decoder_lm_head_for_context[context_length_estimate, batch_size]
+                    )
+                    # PERF: No latency improvement seen in multi-layer models from executor
+                    # Pipeline parallel doesn't support executor right now
+                    if self.context_unroll == self.config.num_hidden_layers and not self.neuron_config.is_pp():
+                        model.use_executor = True
+                    self.decoder_lm_head_for_context[context_length_estimate, batch_size] = model
+
+        if self.decoder_lm_head_for_speculation:
+            for i, k in enumerate(self.decoder_lm_head_for_speculation):
+                model = self.decoder_lm_head.build_weight_shared(
+                    share_caches=True,
+                    new=self.decoder_lm_head_for_speculation[k],
+                    embed_weight=self.chkpt_model.model.embed_tokens.weight,
+                )
+                self.decoder_lm_head_for_speculation[k] = model
+
+        if self.decoder_lm_head_for_window_context:
+            for i, k in enumerate(self.decoder_lm_head_for_window_context):
+                model = self.decoder_lm_head.build_weight_shared(
+                    share_caches=True, new=self.decoder_lm_head_for_window_context[k]
+                )
+                self.decoder_lm_head_for_window_context[k] = model
+
+    def set_prefixed(self, input_ids):
+        self.prefixed_input_ids = input_ids[:, : self.prefixed_length]
+        prefixed_length = self.prefixed_length
+        self.prefixed_length = 0
+        self.forward(self.prefixed_input_ids)
+        self.prefixed_length = prefixed_length
+
+    def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwargs):
+        padded_inputs, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids, **kwargs)
+        if not self.neuron_config.on_device_embedding:
+            input_embeddings = self.chkpt_model.model.embed_tokens(padded_inputs)
+            if self.neuron_config.attention_layout == LAYOUT_HSB:
+                input_embeddings = input_embeddings.transpose(0, -1).contiguous()
+        else:
+            # embedding layer is on device and will be computed as part of self._forward(), so don't compute here
+            input_embeddings = None
+        return padded_inputs, input_embeddings, *rst
+
+    def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs):
+        if last_token_id is not None:  # preprocess_and_embed() has already been invoked
+            rst = cache_ids, start_ids, last_token_id
+        else:  # invoke preprocess_and_embed()
+            input_ids, input_embeddings, *rst = 
self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs) + # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) + inputs = input_embeddings if input_embeddings is not None else input_ids + logits = self._forward(inputs, *rst) + logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + return logits diff --git a/optimum/neuron/models/qwen2/modules.py b/optimum/neuron/models/qwen2/modules.py new file mode 100644 index 000000000..c4ef4a219 --- /dev/null +++ b/optimum/neuron/models/qwen2/modules.py @@ -0,0 +1,85 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from transformers_neuronx import dtypes, module, utils + +from .config import Qwen2Config + + +class Qwen2ForCausalLM(module.PretrainedModel): + + def __init__(self, config: Qwen2Config): + super().__init__() + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.model = Qwen2Model(config) + self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) + + def get_tied_parameters(self): + return [(self.model.embed_tokens.weight, self.lm_head.weight)] + + def get_base_model(self): + return self.model + + +class Qwen2Model(module.LowMemoryModule): + + def __init__(self, config: Qwen2Config): + super().__init__() + self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) + self.layers = module.LowMemoryModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = Qwen2RMSNorm(config) + + +class Qwen2RMSNorm(module.LowMemoryModule): + + def __init__(self, config: Qwen2Config) -> None: + super().__init__() + self.weight = module.UninitializedParameter() + + +class Qwen2DecoderLayer(module.LowMemoryModule): + + def __init__(self, config: Qwen2Config): + super().__init__() + self.self_attn = Qwen2Attention(config) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config) + self.post_attention_layernorm = Qwen2RMSNorm(config) + + +class Qwen2Attention(module.LowMemoryModule): + + def __init__(self, config: Qwen2Config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.v_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) + + +class Qwen2MLP(module.LowMemoryModule): + + def __init__(self, config: Qwen2Config): + super().__init__() + dtype, 
_, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.gate_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.up_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) diff --git a/optimum/neuron/utils/input_generators.py b/optimum/neuron/utils/input_generators.py index c98cb28eb..c3cceddd3 100644 --- a/optimum/neuron/utils/input_generators.py +++ b/optimum/neuron/utils/input_generators.py @@ -18,7 +18,7 @@ import torch -from ...utils import ( +from optimum.utils import ( DTYPE_MAPPER, DummyAudioInputGenerator, DummyInputGenerator, diff --git a/tests/decoder/conftest.py b/tests/decoder/conftest.py index b8346a7fb..60d728945 100644 --- a/tests/decoder/conftest.py +++ b/tests/decoder/conftest.py @@ -33,6 +33,10 @@ "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}, }, + "qwen2": { + "model_id": "Qwen/Qwen2.5-0.5B", + "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}, + }, "mistral": { "model_id": "optimum/mistral-1.1b-testing", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}, diff --git a/tests/decoder/test_decoder_export.py b/tests/decoder/test_decoder_export.py index 8a4bb6712..9224ecb22 100644 --- a/tests/decoder/test_decoder_export.py +++ b/tests/decoder/test_decoder_export.py @@ -30,6 +30,7 @@ "mistral": "dacorvo/tiny-random-MistralForCausalLM", "mixtral": "dacorvo/Mixtral-tiny", "opt": "hf-internal-testing/tiny-random-OPTForCausalLM", + "qwen2": "yujiepan/qwen2.5-128k-tiny-random", } diff --git a/text-generation-inference/tests/fixtures/model.py b/text-generation-inference/tests/fixtures/model.py index b1e785308..73f633862 100644 --- a/text-generation-inference/tests/fixtures/model.py +++ b/text-generation-inference/tests/fixtures/model.py @@ -37,6 +37,10 @@ "model_id": "optimum/mistral-1.1b-testing", "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}, }, + "qwen2": { + "model_id": "Qwen/Qwen2.5-0.5B", + "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}, + }, } diff --git a/text-generation-inference/tests/integration/test_generate.py b/text-generation-inference/tests/integration/test_generate.py index da9e5ea93..0f75a82ad 100644 --- a/text-generation-inference/tests/integration/test_generate.py +++ b/text-generation-inference/tests/integration/test_generate.py @@ -24,6 +24,7 @@ async def test_model_single_request(tgi_service): "gpt2": "\n\nDeep learning is a new field of research that has been around for a while", "llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", + "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", } assert response.generated_text == greedy_expectations[service_name] @@ -48,6 +49,7 @@ async def test_model_single_request(tgi_service): "gpt2": "Deep Learning", "llama": "Deep Learning", "mistral": "Deep learning", + "qwen2": "Deep Learning", } assert sample_expectations[service_name] in response @@ -81,6 +83,7 @@ async def test_model_multiple_requests(tgi_service, generate_load): "gpt2": "\n\nDeep 
learning is a new field of research that has been around for a while", "llama": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", + "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", } expected = expectations[tgi_service.client.service_name] for r in responses: diff --git a/text-generation-inference/tests/server/test_decode.py b/text-generation-inference/tests/server/test_decode.py index 30c6f9151..7b69eae98 100644 --- a/text-generation-inference/tests/server/test_decode.py +++ b/text-generation-inference/tests/server/test_decode.py @@ -35,14 +35,19 @@ def _test_decode(config_name, generator, do_sample): assert output.generated_tokens == max_new_tokens assert output.finish_reason == 0 if do_sample: - expected_text = {"gpt2": " The sun was set", "llama": "George Orwell, 1984", "mistral": "The sky was"}[ - config_name - ] + expected_text = { + "gpt2": " The sun was set", + "llama": "George Orwell, 1984", + "mistral": "The sky was", + "qwen2": " A young woman with", + }[config_name] assert expected_text in output.text else: + print(output.text) expected_text = { "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going', "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story", "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.", + "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a", }[config_name] assert output.text == expected_text diff --git a/text-generation-inference/tests/server/test_prefill.py b/text-generation-inference/tests/server/test_prefill.py index c567feaec..7c50fd6bf 100644 --- a/text-generation-inference/tests/server/test_prefill.py +++ b/text-generation-inference/tests/server/test_prefill.py @@ -34,9 +34,19 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert next_batch.max_tokens == batch_size * max_length assert len(generations) == batch_size if do_sample: - expectations = {"gpt2": [383, " The"], "llama": [10058, " George"], "mistral": [450, " The"]}[config_name] + expectations = { + "gpt2": [383, " The"], + "llama": [10058, " George"], + "mistral": [450, " The"], + "qwen2": [362, " A"], + }[config_name] else: - expectations = {"gpt2": [198, "\n"], "llama": [10058, " George"], "mistral": [13, "\n"]}[config_name] + expectations = { + "gpt2": [198, "\n"], + "llama": [10058, " George"], + "mistral": [13, "\n"], + "qwen2": [358, " I"], + }[config_name] for g in generations: tokens = g.tokens assert tokens.ids[0] == expectations[0] @@ -69,6 +79,7 @@ def test_prefill_truncate(neuron_model_config): "gpt2": [" He", " He", "\n", " He"], "llama": [" —", " The", " He", " He"], "mistral": [" He", "\n", " He", " He"], + "qwen2": [" He", " The", " He", " He"], }[config_name] for i, g in enumerate(generations): tokens = g.tokens