[Hardware][intel GPU] bump up ipex version to 2.3 (vllm-project#8365)
Co-authored-by: Yan Ma <yan.ma@intel.com>
2 people authored and dtrifiro committed Sep 27, 2024
1 parent 9ebf28d commit 147926d
Showing 6 changed files with 60 additions and 87 deletions.
Dockerfile.xpu: 10 additions & 2 deletions
@@ -1,15 +1,23 @@
-FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04
+FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
    rm /etc/apt/sources.list.d/intel-graphics.list && \
    wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
    chmod 644 /usr/share/keyrings/intel-graphics.gpg

RUN apt-get update -y \
&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1

+RUN git clone https://github.com/intel/pti-gpu && \
+    cd pti-gpu/sdk && \
+    mkdir build && \
+    cd build && \
+    cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
+    make -j && \
+    cmake --install . --config Release --prefix "/usr/local"

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm
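
The pti-gpu stage added above builds Intel's Profiling Tools Interface SDK from source and installs it under /usr/local. A quick way to confirm the runtime landed on the loader path inside the built image is a ctypes probe; this is a minimal sketch, and the soname "libpti_view.so" is an assumption about what the SDK build produces:

    # Minimal sketch: verify the PTI runtime is loadable inside the image.
    # The soname "libpti_view.so" is an assumption about the SDK's output;
    # adjust it if the build installs a differently named library.
    import ctypes

    ctypes.CDLL("libpti_view.so")  # raises OSError if the install is missing
    print("PTI runtime is loadable")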
requirements-xpu.txt: 5 additions & 4 deletions
@@ -3,9 +3,10 @@

setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed.

-torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl
-intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl
+torch == 2.3.1+cxx11.abi
+intel-extension-for-pytorch == 2.3.110+xpu
+oneccl_bind_pt == 2.3.100+xpu

-triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+triton-xpu == 3.0.0b2

+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
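
The hard-coded wheel URLs give way to version pins resolved through the Intel extra index added on the last line. A minimal post-install sanity check for the pinned stack, as a sketch (the expected version strings come straight from the pins above; the xpu check only passes with a working Intel GPU runtime):

    # Sketch: confirm the pinned XPU stack resolved correctly.
    import torch
    import intel_extension_for_pytorch as ipex  # registers the 'xpu' device

    print(torch.__version__)         # expected: 2.3.1+cxx11.abi
    print(ipex.__version__)          # expected: 2.3.110+xpu
    print(torch.xpu.is_available())  # True only with a usable Intel GPU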
vllm/_ipex_ops.py: 29 additions & 69 deletions
@@ -27,29 +27,27 @@ def _reshape_activation_tensor(

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)

    @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)

-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
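
The rewritten activation ops hand the packed gate/up tensor straight to the fused IPEX 2.3 kernels instead of pre-splitting it with _reshape_activation_tensor. For intuition, this is the reference math of silu_and_mul; a sketch of the semantics, not the IPEX kernel itself:

    import torch

    def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # x packs [gate, up] along the last dimension
        gate, up = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up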

    @staticmethod
    def paged_attention_v1(
@@ -160,67 +158,26 @@ def rotary_embedding(
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
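
The manual cos/sin gather-and-rotate path is gone; a single fused rotary_embedding_batched call now applies the rotation in place. For reference, the underlying math for the neox-style case, as a sketch of what the fused kernel computes (assumes rot_dim equals the full head size):

    import torch

    def rotate_neox_ref(q: torch.Tensor, cos: torch.Tensor,
                        sin: torch.Tensor) -> torch.Tensor:
        # q: [..., head_size]; cos/sin: [..., head_size // 2]
        half = q.shape[-1] // 2
        q1, q2 = q[..., :half], q[..., half:]
        rotated = torch.cat((-q2, q1), dim=-1)
        return (q * torch.cat((cos, cos), dim=-1)
                + rotated * torch.cat((sin, sin), dim=-1))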

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)

    @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
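
rms_norm now returns the kernel's result directly instead of copying into a caller-provided out buffer, saving one full-tensor copy per call. Its reference semantics, as a sketch (not the IPEX kernel):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                     epsilon: float) -> torch.Tensor:
        # normalize by the root-mean-square of the last dimension in fp32
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight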

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
@@ -246,11 +203,14 @@ def varlen_attention(
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
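
The new call normalizes its inputs before reaching the kernel: query/key/value are made contiguous (upstream slicing and transposing can leave non-contiguous views) and the cumulative sequence-length tensors are cast to int32. A sketch of those caller-side invariants; the kernel's exact requirements are an assumption inferred from this change:

    import torch

    def prep_varlen_inputs(query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor, seqlen_q: torch.Tensor,
                           seqlen_k: torch.Tensor):
        # contiguous q/k/v buffers and int32 cumulative-seqlen tensors
        return (query.contiguous(), key.contiguous(), value.contiguous(),
                seqlen_q.int(), seqlen_k.int())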

    @staticmethod
    def reshape_and_cache(
vllm/attention/backends/ipex_attn.py: 6 additions & 2 deletions
@@ -49,14 +49,18 @@ def swap_blocks(
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        from vllm._ipex_ops import ipex_ops as ops
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        from vllm._ipex_ops import ipex_ops as ops
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
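
Both block-cache methods now dispatch to the XPU ops in vllm._ipex_ops rather than the CUDA PagedAttention helpers, and copy_blocks first unpacks each cache entry into its key and value halves. The shape assumption behind that unpacking, as a sketch with illustrative dimensions:

    import torch

    # each cache entry stacks K and V along dim 0: kv_cache[0] is the key
    # cache, kv_cache[1] the value cache; indexing yields views, not copies
    num_blocks, num_heads, block_size, head_size = 16, 8, 16, 64
    kv_cache = torch.empty(2, num_blocks, num_heads, block_size, head_size)
    key_cache, value_cache = kv_cache[0], kv_cache[1]
    assert key_cache.data_ptr() == kv_cache.data_ptr()  # no data was copied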


@dataclass
vllm/model_executor/layers/activation.py: 9 additions & 6 deletions
@@ -114,9 +114,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out
+        return ops.gelu_new(x)


class FastGELU(CustomOp):
@@ -136,9 +134,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out
+        return ops.gelu_fast(x)


class QuickGELU(CustomOp):
@@ -155,6 +151,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        ops.gelu_quick(out, x)
        return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out

-    # TODO implement forward_xpu for QuickGELU
-    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
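
With the TODO resolved, QuickGELU gains a real XPU path through the new ipex_ops.gelu_quick. The expected math is the sigmoid approximation of GELU; a reference sketch, not the IPEX kernel:

    import torch

    def gelu_quick_ref(x: torch.Tensor) -> torch.Tensor:
        # quick-GELU approximates GELU as x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)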

vllm/model_executor/layers/layernorm.py: 1 addition & 4 deletions
@@ -82,14 +82,11 @@ def forward_xpu(
                self.variance_epsilon,
            )
            return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
+        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )
-        return out
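
At the layer level the refactor replaces the out-parameter pattern with a plain functional call. A sketch of the new call site, assuming an XPU build where vllm._ipex_ops imports cleanly and an 'xpu' device is present:

    import torch
    from vllm._ipex_ops import ipex_ops as ops

    x = torch.randn(4, 1024, device="xpu")
    weight = torch.ones(1024, device="xpu")
    out = ops.rms_norm(x, weight, 1e-6)  # allocates and returns; no out= buffer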

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
