
Commit 2d96e48

jenniew, Jack-Khuu, and songhappy authored and vmpuri committed
Add Intel XPU device support to generate and serve (#1361)
* add xpu
* add xpu device
* update
* profile
* update install
* update
* update
* update

---------

Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
Co-authored-by: Guoqiong <guoqiong.song@intel.com>
1 parent e5543e2 commit 2d96e48

7 files changed: +56 -21 lines changed

install/install_requirements.sh

+23-13
```diff
@@ -59,12 +59,6 @@ VISION_NIGHTLY_VERSION=dev20241218
 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20241218
 
-# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
-(
-  set -x
-  $PIP_EXECUTABLE uninstall -y triton
-)
-
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
@@ -74,16 +68,28 @@ then
 elif [[ -x "$(command -v rocminfo)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
+elif [[ -x "$(command -v xpu-smi)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
 
 # pip packages needed by exir.
-REQUIREMENTS_TO_INSTALL=(
-  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
-  torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
-)
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0"
+  )
+else
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
+  )
+fi
 
 #
 # First install requirements in install/requirements.txt. Older torch may be
@@ -95,6 +101,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
 )
 
+# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+(
+  set -x
+  $PIP_EXECUTABLE uninstall -y triton
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
@@ -116,8 +128,6 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
   $PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
 )
 fi
-
-
 (
   set -x
   $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0"
```
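As a quick sanity check after running the updated installer on an Intel GPU host (a minimal sketch, not part of this commit), you can confirm that the XPU-enabled nightly wheel was actually selected:

```python
# Minimal post-install check (illustrative, not part of this commit):
# verifies that the installer picked the XPU-enabled PyTorch nightly.
import torch

print(torch.__version__)          # expect a 2.6.0 nightly build string
print(torch.xpu.is_available())   # True when an Intel GPU and driver are present
if torch.xpu.is_available():
    print(torch.xpu.get_device_name(0))
```

Note that on XPU hosts the script pins torchtune to the stable 0.5.0 release rather than the `TUNE_NIGHTLY_VERSION` build used on other devices.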

torchchat/cli/builder.py

+6-1
```diff
@@ -72,7 +72,12 @@ class BuilderArgs:
 
     def __post_init__(self):
         if self.device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.xpu.is_available():
+                self.device = "xpu"
+            else:
+                self.device = "cpu"
 
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
```
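The default-device fallback order is now cuda, then xpu, then cpu. A standalone sketch of the same selection logic (hypothetical helper, not part of torchchat):

```python
import torch

def default_device() -> str:
    # Hypothetical helper mirroring BuilderArgs.__post_init__:
    # prefer CUDA, then Intel XPU, then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
```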

torchchat/cli/cli.py

+2-2
```diff
@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps",
+        choices=["fast", "cpu", "cuda", "mps", "xpu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
     )
 
 
```
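Since `choices` handles validation, argparse rejects any other device string before torchchat sees it. A self-contained sketch of the extended flag's behavior (illustrative parser, not the actual torchchat CLI):

```python
import argparse

# Illustrative stand-in for torchchat's parser; only the --device flag is shown.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default=None,
    choices=["fast", "cpu", "cuda", "mps", "xpu"],
    help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
)

print(parser.parse_args(["--device", "xpu"]).device)  # xpu
# parser.parse_args(["--device", "npu"]) would exit with "invalid choice"
```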

torchchat/generate.py

+8-1
```diff
@@ -1203,8 +1203,10 @@ def callback(x, *, done_generating=False):
             if hasattr(prof, "export_chrome_trace"):
                 if self.builder_args.device == "cpu":
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-                else:
+                elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+                else:
+                    print(prof.key_averages().table(sort_by="self_xpu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:
@@ -1289,6 +1291,9 @@ def callback(x, *, done_generating=False):
             )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        if torch.xpu.is_available():
+            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
 
 
 class DistributedGenerator(LocalGenerator):
@@ -1615,6 +1620,8 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
```
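The profiler table is now sorted by a device-appropriate column, and the peak-memory reporting and resets mirror the existing CUDA calls through the `torch.xpu` namespace. A compact sketch of the sort-key dispatch (hypothetical helper, not part of this commit):

```python
def profiler_sort_key(device: str) -> str:
    # Hypothetical helper mirroring the branch added above: cpu and cuda
    # keep their existing columns; any other device (here, xpu) sorts by
    # the self_xpu_time_total column.
    if device == "cpu":
        return "self_cpu_time_total"
    if device == "cuda":
        return "self_cuda_time_total"
    return "self_xpu_time_total"
```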

torchchat/utils/build_utils.py

+6-2
```diff
@@ -231,6 +231,8 @@ def find_multiple(n: int, k: int) -> int:
 def device_sync(device="cpu"):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -279,7 +281,8 @@ def get_device_str(device) -> str:
         device = (
             "cuda"
             if torch.cuda.is_available()
-            else "mps" if is_mps_available() else "cpu"
+            else "mps" if is_mps_available()
+            else "xpu" if torch.xpu.is_available() else "cpu"
         )
         return device
     else:
@@ -291,7 +294,8 @@ def get_device(device) -> str:
         device = (
            "cuda"
            if torch.cuda.is_available()
-           else "mps" if is_mps_available() else "cpu"
+           else "mps" if is_mps_available()
+           else "xpu" if torch.xpu.is_available() else "cpu"
        )
        return torch.device(device)
```
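The chained conditional expression is compact but easy to misread; it is equivalent to this if/elif form (a readability sketch only, not what the commit uses; `is_mps_available` is passed in here since it lives in the same module):

```python
import torch

def resolve_fast_device(is_mps_available) -> str:
    # Equivalent to the nested conditional in get_device_str/get_device:
    # prefer cuda, then mps, then xpu, then cpu.
    if torch.cuda.is_available():
        return "cuda"
    if is_mps_available():
        return "mps"
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
```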

torchchat/utils/device_info.py

+10-1
```diff
@@ -14,7 +14,7 @@ def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type
 
     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
     Returns:
         str: A human-readable description of the hardware or an empty string if the device type is unhandled.
 
@@ -37,4 +37,13 @@ def get_device_info(device: str) -> str:
         )
     if device == "cuda":
         return torch.cuda.get_device_name(0)
+    if device == "xpu":
+        return (
+            check_output(
+                ["xpu-smi discovery |grep 'Device Name:'"], shell=True
+            )
+            .decode("utf-8")
+            .split("\n")[0]
+            .split("Device Name:")[1]
+        )
     return ""
```

torchchat/utils/quantize.py

+1-1
```diff
@@ -121,7 +121,7 @@ def quantize_model(
     else:
         ao_quant = True
         # Use tensor subclass API for int4 weight only.
-        if device == "cuda" and quantizer == "linear:int4":
+        if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
             quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
         elif quantizer == "linear:int8":
             print("quantizer is linear int8")
```
