From b3989cf75376749d6163217d4fa9285551b4cfc6 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 5 Mar 2024 18:29:31 -0800
Subject: [PATCH] align to ipex llm ops

Signed-off-by: Wang, Yi A
---
 .../utils/flash_attn.py                       | 14 +++++---
 server/text_generation_server/utils/layers.py | 33 +++++++++++++------
 .../utils/paged_attention.py                  | 19 ++++++++---
 3 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 7333455e5c9..5ca4ca31cd7 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -4,7 +4,11 @@
 from loguru import logger
 import math
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -12,6 +16,9 @@
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -90,7 +97,7 @@ def attention(
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
-        return torch.xpu.varlen_fwd(
+        return ipex.llm.modules.VarlenAttention.apply(
             q,
             k,
             v,
@@ -104,10 +111,9 @@ def attention(
             False,
             True,
             False,
-            None
+            None,
         )
 
-
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7f10a4f9e91..73ece5ae7ff 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,9 +18,16 @@
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 from text_generation_server.utils.log import log_once
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
@@ -646,7 +653,13 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+            out = ipex.llm.modules.FastLayerNorm.apply(
+                hidden_states,
+                self.normalized_shape,
+                self.eps,
+                self.weight,
+                self.bias,
+            )
             return out, residual
         elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
             if residual is not None:
@@ -698,8 +711,11 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.rms_norm(
-                hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
+            out = ipex.llm.modules.RMSNorm.apply(
+                hidden_states,
+                [hidden_states.size(-1)],
+                self.weight,
+                self.variance_epsilon,
             )
             return out[0], residual
         elif hidden_states.shape[-1] > 8192:
@@ -829,15 +845,14 @@ def forward(
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
         elif IS_XPU_SYSTEM:
-            sin = sin.expand(query.shape)
-            cos = cos.expand(query.shape)
-            torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
+            ipex.llm.modules.RotaryEmbedding.apply(
+                query, key, sin, cos, query.size(-1), True
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
             )
 
-
     @classmethod
     def static(cls, config, dim, base, device):
         inv_freq = _create_inv_freq(dim, base, device)
@@ -953,8 +968,6 @@ def get_cos_sin(
 
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
-        if IS_XPU_SYSTEM:
-            return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
         # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
         return cos.unsqueeze(1), sin.unsqueeze(1)
 
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index 284fc017781..b9c42ce77dc 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -1,11 +1,18 @@
 import torch
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     from vllm import cache_ops
     from vllm import attention_ops
 
 _PARTITION_SIZE = 512
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
 
 def reshape_and_cache(
@@ -18,7 +25,9 @@ def reshape_and_cache(
     if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     elif IS_XPU_SYSTEM:
-        torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
 
 
 def attention(
@@ -60,18 +69,18 @@ def attention(
     # to parallelize.
     if IS_XPU_SYSTEM:
         query = query.contiguous()
-        return torch.xpu.IpexPaged_attention(
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
             key_cache,
             value_cache,
             kv_head_mapping,
+            softmax_scale,
             block_tables,
             input_lengths,
-            softmax_scale,
             block_size,
             max_s,
-            None
+            None,
         )
 
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
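Every hunk above follows the same dispatch pattern: intel_extension_for_pytorch is imported only when IS_XPU_SYSTEM is set, and each kernel call is routed either to the vLLM CUDA/ROCm ops or to the corresponding ipex.llm.modules op. As a rough, self-contained sketch of that pattern (it reuses only the IS_*_SYSTEM flags and the vllm/ipex calls that already appear in the hunks; the type annotations are illustrative), the kv-cache write path reads roughly as follows after the patch:

    import torch

    from text_generation_server.utils.import_utils import (
        IS_CUDA_SYSTEM,
        IS_ROCM_SYSTEM,
        IS_XPU_SYSTEM,
    )

    # Backend-specific imports: vLLM kernels on CUDA/ROCm, IPEX LLM ops on XPU.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        from vllm import cache_ops
    if IS_XPU_SYSTEM:
        import intel_extension_for_pytorch as ipex


    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slots: torch.Tensor,
    ):
        # Scatter the freshly computed key/value tensors into the paged kv-cache
        # at the given slot indices, using whichever backend is present.
        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
        elif IS_XPU_SYSTEM:
            ipex.llm.modules.PagedAttention.reshape_and_cache(
                key, value, key_cache, value_cache, slots
            )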