diff --git a/gui.sh b/gui.sh index 7f66a3eff..debca42fa 100755 --- a/gui.sh +++ b/gui.sh @@ -59,12 +59,32 @@ if [[ "$OSTYPE" == "darwin"* ]]; then fi else if [ "$RUNPOD" = false ]; then - REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt" + if [[ "$*" == *"--use-ipex"* ]]; then + REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt" + else + REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt" + fi else REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_runpod.txt" fi fi +# Set OneAPI if it's not set by the user +if [[ "$*" == *"--use-ipex"* ]] +then + echo "Setting OneAPI environment" + if [ ! -x "$(command -v sycl-ls)" ] + then + if [[ -z "$ONEAPI_ROOT" ]] + then + ONEAPI_ROOT=/opt/intel/oneapi + fi + source "$ONEAPI_ROOT/setvars.sh" + fi + export NEOReadDebugKeys=1 + export ClDeviceGlobalMemSizeAvailablePercent=100 +fi + # Validate the requirements and run the script if successful if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then python "$SCRIPT_DIR/kohya_gui.py" "$@" diff --git a/kohya_gui.py b/kohya_gui.py index da5a04a3d..da25f2b8f 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -133,6 +133,10 @@ def UI(**kwargs): '--language', type=str, default=None, help='Set custom language' ) + parser.add_argument( + '--use-ipex', action='store_true', help='Use IPEX environment' + ) + args = parser.parse_args() UI( diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt new file mode 100644 index 000000000..61d8a75f4 --- /dev/null +++ b/requirements_linux_ipex.txt @@ -0,0 +1,3 @@ +torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +tensorboard==2.12.3 tensorflow==2.12.0 intel-extension-for-tensorflow[gpu] +-r requirements.txt diff --git a/setup.sh b/setup.sh index 3f521f37c..5cb6a4e83 100755 --- a/setup.sh +++ b/setup.sh @@ -27,6 +27,7 @@ Options: -s, --skip-space-check Skip the 10Gb minimum 
storage space check. -u, --no-gui Skips launching the GUI. -v, --verbose Increase verbosity levels up to 3. + --use-ipex Use IPEX with Intel ARC GPUs. EOF } @@ -87,6 +88,7 @@ MAXVERBOSITY=6 DIR="" PARENT_DIR="" VENV_DIR="" +USE_IPEX=false # Function to get the distro name get_distro_name() { @@ -203,6 +205,8 @@ install_python_dependencies() { "lin"*) if [ "$RUNPOD" = true ]; then python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt + elif [ "$USE_IPEX" = true ]; then + python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt else python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt fi @@ -318,6 +322,7 @@ while getopts ":vb:d:g:inprus-:" opt; do s | skip-space-check) SKIP_SPACE_CHECK=true ;; u | no-gui) SKIP_GUI=true ;; v) ((VERBOSITY = VERBOSITY + 1)) ;; + use-ipex) USE_IPEX=true ;; h) display_help && exit 0 ;; *) display_help && exit 0 ;; esac diff --git a/setup/setup_common.py b/setup/setup_common.py index 8d94ca9f3..9a0ecdef3 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -195,12 +195,24 @@ def check_torch(): '/opt/rocm/bin/rocminfo' ): log.info('AMD toolkit detected') + elif (shutil.which('sycl-ls') is not None + or os.environ.get('ONEAPI_ROOT') is not None + or os.path.exists('/opt/intel/oneapi')): + log.info('Intel OneAPI toolkit detected') else: log.info('Using CPU-only Torch') try: import torch - + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() + os.environ.setdefault('NEOReadDebugKeys', '1') + os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') + except Exception: + pass log.info(f'Torch {torch.__version__}') # Check if CUDA is available @@ -208,10 +220,14 @@ def check_torch(): log.warning('Torch reports CUDA not available') else: if torch.version.cuda: - # Log nVidia CUDA and cuDNN versions - log.info( - f'Torch 
backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + # Log Intel IPEX OneAPI version + log.info(f'Torch backend: Intel IPEX OneAPI {ipex.__version__}') + else: + # Log nVidia CUDA and cuDNN versions + log.info( + f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' + ) elif torch.version.hip: # Log AMD ROCm HIP version log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') @@ -222,9 +238,14 @@ def check_torch(): for device in [ torch.cuda.device(i) for i in range(torch.cuda.device_count()) ]: - log.info( - f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + log.info( + f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + ) + else: + log.info( + f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + ) return int(torch.__version__[0]) except Exception as e: # log.warning(f'Could not load torch: {e}') diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index 73e94eb65..9d17ffda2 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -35,12 +35,22 @@ def check_torch(): '/opt/rocm/bin/rocminfo' ): log.info('AMD toolkit detected') + elif 
(shutil.which('sycl-ls') is not None + or os.environ.get('ONEAPI_ROOT') is not None + or os.path.exists('/opt/intel/oneapi')): + log.info('Intel OneAPI toolkit detected') else: log.info('Using CPU-only Torch') try: import torch - + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() + except Exception: + pass log.info(f'Torch {torch.__version__}') # Check if CUDA is available @@ -48,10 +58,14 @@ def check_torch(): log.warning('Torch reports CUDA not available') else: if torch.version.cuda: - # Log nVidia CUDA and cuDNN versions - log.info( - f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + # Log Intel IPEX OneAPI version + log.info(f'Torch backend: Intel IPEX {ipex.__version__}') + else: + # Log nVidia CUDA and cuDNN versions + log.info( + f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' + ) elif torch.version.hip: # Log AMD ROCm HIP version log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') @@ -62,9 +76,14 @@ def check_torch(): for device in [ torch.cuda.device(i) for i in range(torch.cuda.device_count()) ]: - log.info( - f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + log.info( + f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + ) + else: + log.info( + f'Torch detected GPU: 
{torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + ) return int(torch.__version__[0]) except Exception as e: log.error(f'Could not load torch: {e}')