align to ipex llm ops
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi committed Mar 6, 2024
1 parent 3a79d15 commit b3989cf
Showing 3 changed files with 47 additions and 19 deletions.
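
All three files apply the same change: Intel Extension for PyTorch is imported only on XPU systems, and the XPU branch of each op is routed through the ipex.llm.modules API instead of the older torch.xpu / torch.ops.torch_ipex entry points. A minimal sketch of that shared guard-and-import pattern, taken directly from the hunks below:

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

# Only XPU systems get the IPEX import; CUDA/ROCm keep their existing kernels.
if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex
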
server/text_generation_server/utils/flash_attn.py (14 changes: 10 additions & 4 deletions)
@@ -4,14 +4,21 @@
 from loguru import logger
 import math
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -90,7 +97,7 @@ def attention(
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
-        return torch.xpu.varlen_fwd(
+        return ipex.llm.modules.VarlenAttention.apply(
             q,
             k,
             v,
Expand All @@ -104,10 +111,9 @@ def attention(
             False,
             True,
             False,
-            None
+            None,
         )
 
-
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
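
The change above swaps the raw torch.xpu.varlen_fwd kernel for ipex.llm.modules.VarlenAttention on the XPU path, keeping the same positional arguments. A minimal sketch of that dispatch is below; the middle arguments are collapsed in this view of the diff, so they are passed through untouched here, and the signature is assumed from the changed lines rather than taken from IPEX documentation:

import intel_extension_for_pytorch as ipex  # assumed importable on an XPU system


def xpu_varlen_attention(q, k, v, *collapsed_args):
    # Hypothetical helper for illustration only. Before this commit the XPU
    # branch called torch.xpu.varlen_fwd(q, k, v, ...); after it, the same
    # positional arguments (ending in False, True, False, None) go through
    # the IPEX LLM module instead.
    return ipex.llm.modules.VarlenAttention.apply(q, k, v, *collapsed_args)
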
server/text_generation_server/utils/layers.py (33 changes: 23 additions & 10 deletions)
@@ -18,9 +18,16 @@
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 from text_generation_server.utils.log import log_once
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
@@ -646,7 +653,13 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.fast_layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+            out = ipex.llm.modules.FastLayerNorm.apply(
+                hidden_states,
+                self.normalized_shape,
+                self.eps,
+                self.weight,
+                self.bias,
+            )
             return out, residual
         elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
             if residual is not None:
@@ -698,8 +711,11 @@ def forward(self, hidden_states, residual=None):
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
-            out = torch.ops.torch_ipex.rms_norm(
-                hidden_states, [hidden_states.size(-1)], self.weight, self.variance_epsilon
+            out = ipex.llm.modules.RMSNorm.apply(
+                hidden_states,
+                [hidden_states.size(-1)],
+                self.weight,
+                self.variance_epsilon,
             )
             return out[0], residual
         elif hidden_states.shape[-1] > 8192:
@@ -829,15 +845,14 @@ def forward(
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
         elif IS_XPU_SYSTEM:
-            sin = sin.expand(query.shape)
-            cos = cos.expand(query.shape)
-            torch.ops.torch_ipex.apply_rotary_embedding_half_qk(query, key, sin, cos, query, key)
+            ipex.llm.modules.RotaryEmbedding.apply(
+                query, key, sin, cos, query.size(-1), True
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
             )
 
-
     @classmethod
     def static(cls, config, dim, base, device):
         inv_freq = _create_inv_freq(dim, base, device)
@@ -953,8 +968,6 @@ def get_cos_sin(
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
 
-        if IS_XPU_SYSTEM:
-            return cos.unsqueeze(1).repeat(1, 1, 2), sin.unsqueeze(1).repeat(1, 1, 2)
         # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
         return cos.unsqueeze(1), sin.unsqueeze(1)
 
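
For reference, the three layer-level ops touched above can be read as the following sketch. The helper names are hypothetical, the argument orders are copied from the changed lines only (note that FastLayerNorm.apply now takes eps before weight and bias, and RMSNorm.apply returns a tuple whose first element is used), and none of the signatures have been checked against IPEX documentation:

import intel_extension_for_pytorch as ipex  # assumed importable on an XPU system


def xpu_layer_norm(hidden_states, normalized_shape, weight, bias, eps, residual=None):
    # Residual handling mirrors the diff: add the residual, then return it.
    if residual is not None:
        hidden_states += residual
    residual = hidden_states
    # eps comes before weight/bias, unlike the old torch.ops.torch_ipex.fast_layer_norm call.
    out = ipex.llm.modules.FastLayerNorm.apply(
        hidden_states, normalized_shape, eps, weight, bias
    )
    return out, residual


def xpu_rms_norm(hidden_states, weight, variance_epsilon, residual=None):
    if residual is not None:
        hidden_states += residual
    residual = hidden_states
    # The diff keeps out[0], so the op apparently returns a tuple.
    out = ipex.llm.modules.RMSNorm.apply(
        hidden_states, [hidden_states.size(-1)], weight, variance_epsilon
    )
    return out[0], residual


def xpu_rotary_embedding(query, key, sin, cos):
    # Replaces the sin/cos expand plus torch.ops.torch_ipex.apply_rotary_embedding_half_qk
    # combination; the op updates query and key in place.
    ipex.llm.modules.RotaryEmbedding.apply(query, key, sin, cos, query.size(-1), True)

Because the new rotary op no longer needs sin and cos expanded to the query shape, the XPU-specific repeat in get_cos_sin is dropped as well, as the last hunk above shows.
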
server/text_generation_server/utils/paged_attention.py (19 changes: 14 additions & 5 deletions)
@@ -1,11 +1,18 @@
 import torch
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 
 if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
     from vllm import cache_ops
     from vllm import attention_ops
 
 _PARTITION_SIZE = 512
 
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+
 
 def reshape_and_cache(
Expand All @@ -18,7 +25,9 @@ def reshape_and_cache(
     if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     elif IS_XPU_SYSTEM:
-        torch.xpu.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
 
 
 def attention(
@@ -60,18 +69,18 @@
     # to parallelize.
     if IS_XPU_SYSTEM:
         query = query.contiguous()
-        return torch.xpu.IpexPaged_attention(
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
             query,
             key_cache,
             value_cache,
             kv_head_mapping,
+            softmax_scale,
             block_tables,
             input_lengths,
-            softmax_scale,
             block_size,
             max_s,
-            None
+            None,
         )
 
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
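
The paged-attention path gets the same treatment: the cache reshape and the single-query attention kernel now go through ipex.llm.modules.PagedAttention. The sketch below mirrors only what the changed lines show; the helper names are hypothetical, and the placement of softmax_scale (after kv_head_mapping in the new call) is inferred from the diff rather than confirmed against IPEX documentation:

import intel_extension_for_pytorch as ipex  # assumed importable on an XPU system


def xpu_reshape_and_cache(key, value, key_cache, value_cache, slots):
    # Drop-in replacement for the old torch.xpu.reshape_and_cache call.
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def xpu_paged_attention(
    out, query, key_cache, value_cache, kv_head_mapping,
    softmax_scale, block_tables, input_lengths, block_size, max_s,
):
    # The XPU kernel expects a contiguous query, as in the diff.
    query = query.contiguous()
    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        block_size,
        max_s,
        None,  # trailing argument kept as None, exactly as in the diff
    )
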
