
Commit 78f4590

[Bugfix][XPU] fix silu_and_mul (#11823)
Signed-off-by: yan ma <yan.ma@intel.com>
1 parent: 2f70249

2 files changed: +4 −5 lines changed

vllm/model_executor/layers/activation.py

Lines changed: 2 additions & 2 deletions
@@ -64,8 +64,8 @@ def __init__(self):
         if current_platform.is_cuda_alike() or current_platform.is_cpu():
             self.op = torch.ops._C.silu_and_mul
         elif current_platform.is_xpu():
-            import intel_extension_for_pytorch as ipex
-            self.op = ipex.llm.functional.silu_and_mul
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""

vllm/plugins/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -63,16 +63,15 @@ def load_general_plugins():
     from vllm.platforms import current_platform
 
     if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa
-        os.environ['TORCH_COMPILE_DISABLE'] = 'True'
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
     if current_platform.is_hpu():
         # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
         # does not support torch.compile
         # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
         # torch.compile support
         is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
         if is_lazy:
-            # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
             torch._dynamo.config.disable = True
             # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
             # requires enabling lazy collectives
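
Why the change from the environment variable to the config attribute: the linked config.py line reads TORCH_COMPILE_DISABLE once, at module import time, so setting os.environ from inside an already-running process (as the old XPU branch did) may be too late to take effect, while assigning torch._dynamo.config.disable works at any point. It also matches the existing HPU branch just below. A minimal standalone sketch of the new behavior (not vLLM code; assumes a PyTorch version with torch.compile):

    import torch

    # With this flag set, torch.compile becomes a no-op and compiled
    # functions fall back to eager execution.
    torch._dynamo.config.disable = True

    @torch.compile
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

    print(f(torch.randn(4)))  # runs eagerly; no graph is captured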
