From db09329873966a2137cd31ce54c9ef92b6cba981 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 22 May 2024 13:28:20 -0700 Subject: [PATCH] [Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893) The 2nd PR for #4532. This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter). --- benchmarks/benchmark_latency.py | 14 ++-- benchmarks/benchmark_throughput.py | 12 ++- .../kernels/benchmark_paged_attention.py | 10 +-- tests/models/test_fp8.py | 80 ++++++++++++------- vllm/attention/layer.py | 27 ++++++- vllm/config.py | 8 +- vllm/engine/arg_utils.py | 7 +- .../model_executor/layers/quantization/fp8.py | 47 ++++++++++- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 6 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/chatglm.py | 13 ++- vllm/model_executor/models/commandr.py | 13 ++- vllm/model_executor/models/dbrx.py | 13 ++- vllm/model_executor/models/deepseek.py | 3 +- vllm/model_executor/models/falcon.py | 9 ++- vllm/model_executor/models/gemma.py | 3 +- vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/jais.py | 13 ++- vllm/model_executor/models/llama.py | 32 ++++---- vllm/model_executor/models/minicpm.py | 3 +- vllm/model_executor/models/mixtral.py | 29 +++++-- vllm/model_executor/models/mixtral_quant.py | 15 ++-- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/olmo.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/orion.py | 3 +- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 3 +- vllm/model_executor/models/stablelm.py | 3 +- vllm/model_executor/models/starcoder2.py | 15 ++-- vllm/model_executor/models/xverse.py | 3 +- vllm/utils.py | 2 + vllm/worker/model_runner.py | 17 ++-- 40 files changed, 284 insertions(+), 158 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index f84e3453947c9..a9657f7859750 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -153,15 +153,13 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 41f443968c3c4..7c8cb5ee8cea2 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -323,15 +323,13 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=["auto", "fp8"], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index ca7967c1ab0d2..fc9621e885dc4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -183,13 +183,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") args = parser.parse_args() print(args) diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..0a5819ea3f054 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -16,31 +16,55 @@ MAX_MODEL_LEN = 1024 MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", "meta-llama/Meta-Llama-3-8B-Instruct", ] EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', - ], - "meta-llama/Meta-Llama-3-8B-Instruct": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' - ], + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system made up of several basic components that work together to enable it to', + 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' + ] + }, + "meta-llama/Meta-Llama-3-8B-Instruct": { + "auto": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' + ] + }, } capability = torch.cuda.get_device_capability() @@ -52,14 +76,14 @@ @pytest.mark.skipif(fp8_not_supported, reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) -def test_models( - example_prompts, - model_name, -) -> None: +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +def test_models(example_prompts, model_name, kv_cache_dtype) -> None: model = LLM(model=model_name, max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, enforce_eager=True, - quantization="fp8") + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ @@ -81,8 +105,8 @@ def test_models( generations.append(outputs[0].outputs[0].text) del model - print(generations) - expected_strs = EXPECTED_STRS_MAP[model_name] + print(model_name, kv_cache_dtype, generations) + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] for i in range(len(example_prompts)): generated_str = generations[i] expected_str = expected_strs[i] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 4299726bdca4b..dc7b3940bc9b7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,8 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) class Attention(nn.Module): @@ -30,6 +32,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() if cache_config is not None: @@ -40,6 +43,27 @@ def __init__( block_size = 16 if num_kv_heads is None: num_kv_heads = num_heads + + # The default kv_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized kv_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._kv_scale = 1.0 + quant_method = quant_config.get_quant_method( + self) if quant_config else None + if quant_method is not None: + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # When FP8 quantization is enabled, we make a parameter + # "kv_scale" so that it can be loaded from FP8 checkpoint. + # The kv_scale will then be converted back + # to self._kv_scale in a native float32 value after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() @@ -57,10 +81,9 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, - kv_scale) + self._kv_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/config.py b/vllm/config.py index 773655aa6c793..33b49a0fb2284 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -387,14 +387,12 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8": + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " - "But it may cause slight accuracy drop without scaling " - "factors. FP8_E5M2 (without scaling) is only supported on " - "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " - "is instead supported for common inference criteria.") + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b9f0f2efaa177..49cda76233cf9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -195,12 +195,11 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' - 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=nullable_str, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ff996741c1d00..b084b9cee4983 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -8,8 +8,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -58,9 +59,13 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): return Fp8LinearMethod(self) + if isinstance(layer, Attention): + return Fp8KVCacheMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -251,6 +256,44 @@ def apply(self, return torch.narrow(output, 0, 0, x.shape[0]) +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint.") + del layer.kv_scale + + def all_close_1d(x: torch.Tensor) -> bool: assert len(x.shape) == 1 return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index cb99939cbb17a..313762b1353d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -268,7 +268,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 58b3405d319d1..babb92e7cdcef 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -154,7 +154,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.rotary_emb = get_rope( self.head_dim, @@ -166,7 +167,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index fe2de87b20dc9..a29aee4cffb7d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -111,7 +111,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ed65d76f7b5b9..e3a5e43e23e1c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -86,13 +86,12 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7354d11f98b15..84786921ce1b4 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -177,13 +177,12 @@ def __init__( rope_scaling=self.rope_scaling, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 083ddf0159f71..8ff19a2015e0f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -218,13 +218,12 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 62e04f9649915..8fbda2638aaa3 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -232,7 +232,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ab9e1994be426..ba707adb03dfe 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -153,7 +153,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + quant_config=quant_config) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -165,13 +166,15 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index d1502b718a773..27dda00b66af4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -157,7 +157,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 0deaa58ed9eb5..cc83f6eb6d94d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -75,7 +75,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index c20fb3230c394..f488ef40039c0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -88,7 +88,8 @@ def __init__( self.head_dim, scale=self.scale, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5f4d8ec3d3a7a..47fd5788a4c35 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -88,7 +88,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index dcb52ff666c95..eb0fcc8f26a58 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -89,7 +89,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 65f7ddb8b082c..e75c567f589c8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -117,7 +117,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index df30fd1ba0a37..869b8fc91fd64 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -105,13 +105,12 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f2996c240aaf4..23141124e69e1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,7 +47,7 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, print_warning_once class LlamaMLP(nn.Module): @@ -119,15 +119,6 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # This will be overwritten by model initialization if we are using it. - # N.B. currently we only support per tensor scalar scaling factors - # & only applicable to ROCm (AMD GPU). - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - self.kv_scale = 1.0 - self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -155,7 +146,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -167,8 +159,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, - self.kv_scale) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -421,6 +412,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -445,7 +449,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: # scaling_factor = tensor_amax / FPtype_max scaling_factor *= 2 if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.kv_scale = scaling_factor + layer_self_attn.attn._kv_scale = scaling_factor else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 0b85cf1c94795..59fbf8e1b35f2 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -236,7 +236,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e3ac33e0452fe..ea95cf7380d54 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -308,14 +308,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -581,6 +580,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ee2626b1c1aa2..9b99ff729aadd 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -213,14 +213,13 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 716ac51cde94d..5f9e4d86f3cd8 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -110,7 +110,8 @@ def __init__( scaling, alibi_slopes=alibi_slopes, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 69f23bbfb5d0a..39270f71ec46f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -96,7 +96,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) # Attention output projection. self.o_proj = RowParallelLinear( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d241756e50f4a..4bf59105dbabb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -91,7 +91,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 59cd42e31b374..133a10e6bb3e8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -121,7 +121,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 193a29d20c894..c8e61735a9bb6 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -110,7 +110,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d158846a3a1f5..d22ea6b79de0f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -106,7 +106,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 97ab6168c3230..ec203c3b9001a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -141,7 +141,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index a0d3b0406ef4a..564536f2dd248 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -241,7 +241,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 8b4a5507feade..a6ed3800bed0f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -127,7 +127,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_key_value_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 3c19d63276a77..91ffd0861c39d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -97,14 +97,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 6ef230a8ebbca..dda13d83f89a3 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -135,7 +135,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/utils.py b/vllm/utils.py index bd47ab055b7b5..f4f027ce70e37 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,6 +31,8 @@ "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, } diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e264fede0ee64..9720363ac300e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,5 @@ import time +import warnings from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np @@ -168,11 +169,21 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently scaled KV cache is only enabled on ROCm + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. if self.model_config.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) self.model.load_kv_cache_scales( self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " @@ -183,10 +194,6 @@ def load_model(self) -> None: "Using FP8 KV cache but no scaling factors " "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - elif self.model_config.quantization_param_path is not None: - logger.warning("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") def save_sharded_state( self,