From dc13e993484cf23c337e93cac9b28e7195dbbbed Mon Sep 17 00:00:00 2001 From: Yohan Na Date: Fri, 30 Aug 2024 15:34:20 +0900 Subject: [PATCH] [MODEL] add Exaone model support (#7819) --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/exaone.py | 617 ++++++++++++++++++++ vllm/transformers_utils/config.py | 11 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/exaone.py | 190 ++++++ 6 files changed, 820 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/models/exaone.py create mode 100644 vllm/transformers_utils/configs/exaone.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 223c68b40766..f727c646b7da 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -51,6 +51,10 @@ Decoder-only Language Models - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - + * - :code:`ExaoneForCausalLM` + - EXAONE-3 + - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. + - ✅︎ * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8591c276b001..fc3d4922aea0 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -22,6 +22,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 000000000000..351bc7e67ca0 --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,617 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.transformers_utils.configs.exaone import ExaoneConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "out_proj", + "gate_up_proj", + "c_proj", + "wte", + "lm_head", + ] + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "c_fc_0": ("gate_up_proj", 0), + "c_fc_1": ("gate_up_proj", 1), + } + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.transformer = ExaoneModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.transformer.h[layer_idx], nn.Identity): + layer_self_attn = self.transformer.h[layer_idx].attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c2276b075c1d..4a03446590fe 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -11,11 +11,11 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - EAGLEConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - UltravoxConfig) + EAGLEConfig, ExaoneConfig, + InternVLChatConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, UltravoxConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -34,6 +34,7 @@ "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, + "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index dc2fd6a859e3..736878b35ad4 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig +from vllm.transformers_utils.configs.exaone import ExaoneConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -22,6 +23,7 @@ "JAISConfig", "MedusaConfig", "EAGLEConfig", + "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py new file mode 100644 index 000000000000..805b8ad93003 --- /dev/null +++ b/vllm/transformers_utils/configs/exaone.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py +# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exaone model configuration""" + +from typing import Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} + + +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class: + `~transformers.ExaoneModel`. It is used to instantiate a GPT Lingvo model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. Read the documentation from : + class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Lingvo model. Defines the number of + different tokens that can be represented by the :obj:`inputs_ids` + passed when calling :class:`~transformers.ExaoneModel`. Vocabulary + size of the model. + Defines the different tokens that can be represented by the + `inputs_ids` passed to the forward method of :class: + `~transformers.EXAONEModel`. + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed by meanpooling + all the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in + the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, + defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`, + :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the + embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling + :class:`~transformers.EXAONEModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, + defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense + of slower backward pass. + Example:: + + >>> from transformers import ExoneModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = ExoneModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rotary_pct=0.25, + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rotary_pct = rotary_pct + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_logit_cap = kwargs.pop("use_logit_cap", False) + self.ln_no_scale = kwargs.pop("ln_no_scale", False) + self.use_gated = kwargs.pop("use_gated", False) + self.use_emb_norm = kwargs.pop("use_emb_norm", False) + self.use_rotary_pos = kwargs.pop("use_rotary_pos", False) + self.rotary_type = kwargs.pop("rotary_type", None) + self.scaling_factor = kwargs.pop("scaling_factor", 1) + self.use_absolute_pos = kwargs.pop("use_absolute_pos", True) + self.use_extra_logit = kwargs.pop("use_extra_logit", True) + self.rotary_expand_length = kwargs.pop("rotary_expand_length", None) + self.rotary_base = kwargs.pop("rotary_base", 10000.0) + self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False) + self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head", + (rotary_pct == 0.25)) + if self.use_rotary_pos: + self.use_absolute_pos = False