@@ -59,12 +59,6 @@ VISION_NIGHTLY_VERSION=dev20241218
59
59
# Nightly version for torchtune
60
60
TUNE_NIGHTLY_VERSION=dev20241218
61
61
62
- # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
63
- (
64
- set -x
65
- $PIP_EXECUTABLE uninstall -y triton
66
- )
67
-
68
62
# The pip repository that hosts nightly torch packages. cpu by default.
69
63
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
70
64
# with cuda for faster execution on cuda GPUs.
74
68
elif [[ -x " $( command -v rocminfo) " ]];
75
69
then
76
70
TORCH_NIGHTLY_URL=" https://download.pytorch.org/whl/nightly/rocm6.2"
71
+ elif [[ -x " $( command -v xpu-smi) " ]];
72
+ then
73
+ TORCH_NIGHTLY_URL=" https://download.pytorch.org/whl/nightly/xpu"
77
74
else
78
75
TORCH_NIGHTLY_URL=" https://download.pytorch.org/whl/nightly/cpu"
79
76
fi
80
77
81
78
# pip packages needed by exir.
82
- REQUIREMENTS_TO_INSTALL=(
83
- torch==" 2.6.0.${PYTORCH_NIGHTLY_VERSION} "
84
- torchvision==" 0.22.0.${VISION_NIGHTLY_VERSION} "
85
- torchtune==" 0.5.0.${TUNE_NIGHTLY_VERSION} "
86
- )
79
+ if [[ -x " $( command -v xpu-smi) " ]];
80
+ then
81
+ REQUIREMENTS_TO_INSTALL=(
82
+ torch==" 2.6.0.${PYTORCH_NIGHTLY_VERSION} "
83
+ torchvision==" 0.22.0.${VISION_NIGHTLY_VERSION} "
84
+ torchtune==" 0.5.0"
85
+ )
86
+ else
87
+ REQUIREMENTS_TO_INSTALL=(
88
+ torch==" 2.6.0.${PYTORCH_NIGHTLY_VERSION} "
89
+ torchvision==" 0.22.0.${VISION_NIGHTLY_VERSION} "
90
+ torchtune==" 0.5.0.${TUNE_NIGHTLY_VERSION} "
91
+ )
92
+ fi
87
93
88
94
#
89
95
# First install requirements in install/requirements.txt. Older torch may be
@@ -95,6 +101,12 @@ REQUIREMENTS_TO_INSTALL=(
95
101
$PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url " ${TORCH_NIGHTLY_URL} "
96
102
)
97
103
104
+ # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
105
+ (
106
+ set -x
107
+ $PIP_EXECUTABLE uninstall -y triton
108
+ )
109
+
98
110
# Install the requirements. --extra-index-url tells pip to look for package
99
111
# versions on the provided URL if they aren't available on the default URL.
100
112
(
@@ -116,8 +128,6 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
116
128
$PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
117
129
)
118
130
fi
119
-
120
-
121
131
(
122
132
set -x
123
133
$PIP_EXECUTABLE install evaluate==" 0.4.3" lm-eval==" 0.4.2" psutil==" 6.0.0"
0 commit comments