reenable xpu for tgi #1939

Merged (1 commit, May 23, 2024)
1 change: 1 addition & 0 deletions Dockerfile_intel
@@ -43,6 +43,7 @@ USER root
 RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
     dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null

 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
     | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/rotary.py
@@ -9,6 +9,8 @@
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
+elif SYSTEM == "xpu":
+    import intel_extension_for_pytorch as ipex


 def _create_inv_freq(dim, base, device):
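The new branch only adds the import; on XPU the rotary embedding is then computed through IPEX instead of the CUDA `rotary_emb` or ROCm vLLM kernels. A quick way to confirm that this import path is usable on a given host might look like the following (an illustrative check, not part of the PR; it assumes an Intel-GPU build of PyTorch plus IPEX is installed):

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 -- registers the torch.xpu backend

# If this passes, the `elif SYSTEM == "xpu"` branch above can import and run.
assert torch.xpu.is_available(), "no XPU device visible to PyTorch"
print(torch.xpu.get_device_name(0))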
2 changes: 1 addition & 1 deletion (file name not captured)
@@ -62,7 +62,7 @@
 elif SYSTEM == "rocm":
     from vllm._C import ops
 else:
-    raise RuntimeError(f"Unsupported system {SYSTEM}")
+    dropout_layer_norm = None


 @dataclass
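This hunk appears to come from the layer-norm module (the file header was not captured above). Replacing the hard RuntimeError with `dropout_layer_norm = None` turns an import-time failure on unsupported systems such as xpu into a soft fallback that downstream code can test for. A minimal sketch of that pattern, assuming the call site checks the handle for None (not the exact TGI code):

import torch
import torch.nn.functional as F

def layer_norm_forward(hidden_states, weight, bias, eps, dropout_layer_norm=None):
    # When the fused extension is unavailable (dropout_layer_norm is None, e.g. on xpu),
    # fall back to the stock PyTorch implementation instead of refusing to import.
    if dropout_layer_norm is None:
        return F.layer_norm(hidden_states, (hidden_states.shape[-1],), weight, bias, eps)
    # Otherwise dispatch to the fused dropout + layer-norm kernel (omitted in this sketch).
    raise NotImplementedError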
79 changes: 40 additions & 39 deletions server/text_generation_server/utils/flash_attn.py
@@ -5,7 +5,9 @@
 import math

 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.flash_attn_triton import triton_attention
+
+if SYSTEM != "xpu":
+    from text_generation_server.utils.flash_attn_triton import triton_attention

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
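Skipping the `flash_attn_triton` import on xpu keeps the module importable on hosts where the Triton dependency is presumably absent. The same effect could be approximated with a try/except (a sketch of the alternative, not what the PR does):

try:
    from text_generation_server.utils.flash_attn_triton import triton_attention
except ImportError:  # e.g. Triton not installed on an xpu host
    triton_attention = None

Gating on SYSTEM, as the diff does, avoids silently masking unrelated import errors.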
@@ -15,43 +17,6 @@
 ROCM_USE_FLASH_ATTN_V2_CK = False
 ROCM_USE_FLASH_ATTN_V2_TRITON = False

-if SYSTEM == "xpu":
-    import intel_extension_for_pytorch as ipex
-
-    def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-    ):
-        if window_size_left <= 0 and window_size_left != -1:
-            raise ValueError("`window_size_left` must be > 0 or -1")
-
-        if window_size_left != -1:
-            raise ValueError(
-                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-        return ipex.llm.functional.varlen_attention(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-
-
 if SYSTEM in {"cuda", "rocm"}:
     if not torch.cuda.is_available():
@@ -124,8 +89,44 @@ def attention(
         logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True

+if SYSTEM == "xpu":
+    import intel_extension_for_pytorch as ipex
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
+
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
+
-if HAS_FLASH_ATTN_V2_CUDA:
+elif HAS_FLASH_ATTN_V2_CUDA:

     def attention(
         q,
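Moving the xpu `attention` wrapper below the CUDA/ROCm probing, and turning the following `if HAS_FLASH_ATTN_V2_CUDA:` into an `elif`, means the CUDA definition can no longer shadow the xpu one. A rough usage sketch of the xpu wrapper with two packed variable-length sequences (shapes, dtype and device string are assumptions, not taken from the PR):

import torch

seq_lens = [5, 7]  # two packed sequences
total_tokens, num_heads, head_dim = sum(seq_lens), 32, 64

q = torch.randn(total_tokens, num_heads, head_dim, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 5, 12], device="xpu", dtype=torch.int32)  # cumulative lengths

# `attention` is the xpu wrapper defined in the hunk above; it fills `out` in place
# through ipex.llm.functional.varlen_attention and rejects windowed attention.
attention(q, k, v, out, cu_seqlens, max_s=max(seq_lens), softmax_scale=head_dim**-0.5)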
2 changes: 1 addition & 1 deletion server/text_generation_server/utils/import_utils.py
@@ -17,7 +17,7 @@ def get_cuda_free_memory(device, memory_fraction):
     return free_memory


-def get_xpu_free_memory(device):
+def get_xpu_free_memory(device, memory_fraction):
     total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
     free_memory = int(total_gpu_memory * 0.5)
     return free_memory
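Adding the (currently unused) `memory_fraction` parameter gives `get_xpu_free_memory` the same signature as `get_cuda_free_memory`, so callers can pass the fraction uniformly; the xpu helper still reports a flat 50% of total memory as free. A sketch of the uniform call shape this enables (the dispatch helper itself is hypothetical):

from text_generation_server.utils.import_utils import (
    SYSTEM,
    get_cuda_free_memory,
    get_xpu_free_memory,
)

def free_memory_for(device, memory_fraction):
    # Same argument list regardless of backend now that the signatures match.
    if SYSTEM == "cuda":
        return get_cuda_free_memory(device, memory_fraction)
    if SYSTEM == "xpu":
        return get_xpu_free_memory(device, memory_fraction)  # fraction currently ignored
    raise NotImplementedError(f"no free-memory helper for system {SYSTEM}")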