Skip to content

Commit

Permalink
refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel pla…
Browse files Browse the repository at this point in the history
…tform (#2132)

* refine get xpu free memory

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable qwen2 in xpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable gemma/gemma2/phi in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
  • Loading branch information
sywangyi authored and ErikKaum committed Jul 26, 2024
1 parent e940b35 commit dca51e4
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 5 deletions.
3 changes: 2 additions & 1 deletion server/text_generation_server/layers/attention/ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def attention(
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention(
Expand All @@ -28,7 +29,7 @@ def attention(
0.0,
softmax_scale,
False,
True,
causal,
False,
None,
)
Expand Down
8 changes: 8 additions & 0 deletions server/text_generation_server/models/flash_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

Expand All @@ -32,6 +33,13 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")

Expand Down
8 changes: 8 additions & 0 deletions server/text_generation_server/models/flash_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

Expand All @@ -32,6 +33,13 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")

Expand Down
8 changes: 8 additions & 0 deletions server/text_generation_server/models/flash_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

Expand All @@ -32,6 +33,13 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")

Expand Down
8 changes: 8 additions & 0 deletions server/text_generation_server/models/flash_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

Expand All @@ -37,6 +38,13 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")

Expand Down
12 changes: 8 additions & 4 deletions server/text_generation_server/utils/import_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from loguru import logger
import subprocess
import os


def is_ipex_available():
Expand All @@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction):
def get_xpu_free_memory(device, memory_fraction):
total_memory = torch.xpu.get_device_properties(device).total_memory
device_id = device.index
query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
free_memory = int(total_memory * 0.95 - used_memory)
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
free_memory = max(
0,
int(
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
),
)
return free_memory


Expand Down

0 comments on commit dca51e4

Please sign in to comment.