Skip to content

Commit

Permalink
Merge pull request bmaltais#1499 from Disty0/master
Browse files Browse the repository at this point in the history
Add IPEX
  • Loading branch information
bmaltais authored Oct 4, 2023
2 parents 90406d9 + 76473d0 commit 1deb505
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 17 deletions.
22 changes: 21 additions & 1 deletion gui.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,32 @@ if [[ "$OSTYPE" == "darwin"* ]]; then
fi
else
# Pick the pip requirements file for this launch mode.
# - RunPod containers get their own pinned set.
# - Otherwise, --use-ipex selects the Intel IPEX (XPU) variant,
#   falling back to the standard Linux requirements.
if [ "$RUNPOD" = false ]; then
    # NOTE: the previous unconditional REQUIREMENTS_FILE assignment here was a
    # dead store (immediately overwritten by the branch below) and was removed.
    if [[ "$@" == *"--use-ipex"* ]]; then
        REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt"
    else
        REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt"
    fi
else
    REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_runpod.txt"
fi
fi

# Set up the Intel OneAPI environment if the user asked for IPEX and has not
# already initialized it themselves (detected via sycl-ls on PATH).
if [[ "$@" == *"--use-ipex"* ]]
then
    echo "Setting OneAPI environment"
    if [ ! -x "$(command -v sycl-ls)" ]
    then
        # Default install prefix when ONEAPI_ROOT is not exported by the user.
        if [[ -z "$ONEAPI_ROOT" ]]
        then
            ONEAPI_ROOT=/opt/intel/oneapi
        fi
        # Quoted: an unquoted expansion breaks if the install path contains
        # spaces or other IFS characters.
        source "$ONEAPI_ROOT/setvars.sh"
    fi
    # Driver tweaks for Intel ARC GPUs — expose full device memory to apps.
    export NEOReadDebugKeys=1
    export ClDeviceGlobalMemSizeAvailablePercent=100
fi

# Validate the requirements and run the script if successful
if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then
python "$SCRIPT_DIR/kohya_gui.py" "$@"
Expand Down
4 changes: 4 additions & 0 deletions kohya_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def UI(**kwargs):
'--language', type=str, default=None, help='Set custom language'
)

# CLI switch requesting the Intel IPEX (XPU) environment; also matched
# textually by gui.sh/setup.sh to pick requirements_linux_ipex.txt.
# NOTE(review): parsed here but not visibly consumed in this hunk — presumably
# handled by the launcher scripts; confirm against gui.sh.
parser.add_argument(
    '--use-ipex', action='store_true', help='Use IPEX environment'
)

args = parser.parse_args()

UI(
Expand Down
3 changes: 3 additions & 0 deletions requirements_linux_ipex.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
tensorboard==2.12.3 tensorflow==2.12.0 intel-extension-for-tensorflow[gpu]
-r requirements.txt
5 changes: 5 additions & 0 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Options:
-s, --skip-space-check Skip the 10Gb minimum storage space check.
-u, --no-gui Skips launching the GUI.
-v, --verbose Increase verbosity levels up to 3.
--use-ipex Use IPEX with Intel ARC GPUs.
EOF
}

Expand Down Expand Up @@ -87,6 +88,7 @@ MAXVERBOSITY=6
DIR=""
PARENT_DIR=""
VENV_DIR=""
USE_IPEX=false

# Function to get the distro name
get_distro_name() {
Expand Down Expand Up @@ -203,6 +205,8 @@ install_python_dependencies() {
"lin"*)
if [ "$RUNPOD" = true ]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
elif [ "$USE_IPEX" = true ]; then
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
else
python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
fi
Expand Down Expand Up @@ -318,6 +322,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
s | skip-space-check) SKIP_SPACE_CHECK=true ;;
u | no-gui) SKIP_GUI=true ;;
v) ((VERBOSITY = VERBOSITY + 1)) ;;
use-ipex) USE_IPEX=true ;;
h) display_help && exit 0 ;;
*) display_help && exit 0 ;;
esac
Expand Down
37 changes: 29 additions & 8 deletions setup/setup_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,23 +195,39 @@ def check_torch():
'/opt/rocm/bin/rocminfo'
):
log.info('AMD toolkit detected')
elif (shutil.which('sycl-ls') is not None
or os.environ.get('ONEAPI_ROOT') is not None
or os.path.exists('/opt/intel/oneapi')):
log.info('Intel OneAPI toolkit detected')
else:
log.info('Using CPU-only Torch')

try:
import torch

try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
os.environ.setdefault('NEOReadDebugKeys', '1')
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
except Exception:
pass
log.info(f'Torch {torch.__version__}')

# Check if CUDA is available
if not torch.cuda.is_available():
log.warning('Torch reports CUDA not available')
else:
if torch.version.cuda:
# Log nVidia CUDA and cuDNN versions
log.info(
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
)
if hasattr(torch, "xpu") and torch.xpu.is_available():
# Log Intel IPEX OneAPI version
log.info(f'Torch backend: Intel IPEX OneAPI {ipex.__version__}')
else:
# Log nVidia CUDA and cuDNN versions
log.info(
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
)
elif torch.version.hip:
# Log AMD ROCm HIP version
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
Expand All @@ -222,9 +238,14 @@ def check_torch():
for device in [
torch.cuda.device(i) for i in range(torch.cuda.device_count())
]:
log.info(
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
)
if hasattr(torch, "xpu") and torch.xpu.is_available():
log.info(
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
)
else:
log.info(
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
)
return int(torch.__version__[0])
except Exception as e:
# log.warning(f'Could not load torch: {e}')
Expand Down
35 changes: 27 additions & 8 deletions setup/validate_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,37 @@ def check_torch():
'/opt/rocm/bin/rocminfo'
):
log.info('AMD toolkit detected')
elif (shutil.which('sycl-ls') is not None
or os.environ.get('ONEAPI_ROOT') is not None
or os.path.exists('/opt/intel/oneapi')):
log.info('Intel OneAPI toolkit detected')
else:
log.info('Using CPU-only Torch')

try:
import torch

try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
log.info(f'Torch {torch.__version__}')

# Check if CUDA is available
if not torch.cuda.is_available():
log.warning('Torch reports CUDA not available')
else:
if torch.version.cuda:
# Log nVidia CUDA and cuDNN versions
log.info(
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
)
if hasattr(torch, "xpu") and torch.xpu.is_available():
# Log Intel IPEX OneAPI version
log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
else:
# Log nVidia CUDA and cuDNN versions
log.info(
f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
)
elif torch.version.hip:
# Log AMD ROCm HIP version
log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
Expand All @@ -62,9 +76,14 @@ def check_torch():
for device in [
torch.cuda.device(i) for i in range(torch.cuda.device_count())
]:
log.info(
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
)
if hasattr(torch, "xpu") and torch.xpu.is_available():
log.info(
f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
)
else:
log.info(
f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}'
)
return int(torch.__version__[0])
except Exception as e:
log.error(f'Could not load torch: {e}')
Expand Down

0 comments on commit 1deb505

Please sign in to comment.