From ef7bf311bd79c8eb5d91f245cf2be58e38faeff8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 7 Sep 2023 22:40:29 +0300 Subject: [PATCH 01/10] Initial Intel ARC support with IPEX --- XTI_hijack.py | 7 + fine_tune.py | 7 + gen_img_diffusers.py | 7 + gui.sh | 17 +- kohya_gui.py | 4 + library/ipex/__init__.py | 164 +++++++++++++++++ library/ipex/diffusers.py | 262 +++++++++++++++++++++++++++ library/ipex/gradscaler.py | 179 ++++++++++++++++++ library/ipex/hijacks.py | 199 ++++++++++++++++++++ requirements_linux_ipex.txt | 3 + sdxl_gen_img.py | 7 + sdxl_minimal_inference.py | 7 + sdxl_train.py | 7 + sdxl_train_control_net_lllite.py | 7 + sdxl_train_control_net_lllite_alt.py | 7 + sdxl_train_network.py | 7 + sdxl_train_textual_inversion.py | 7 + setup.sh | 5 + setup/setup_common.py | 37 +++- setup/validate_requirements.py | 35 +++- train_controlnet.py | 7 + train_db.py | 7 + train_network.py | 7 + train_textual_inversion.py | 7 + train_textual_inversion_XTI.py | 7 + 25 files changed, 993 insertions(+), 17 deletions(-) create mode 100644 library/ipex/__init__.py create mode 100644 library/ipex/diffusers.py create mode 100644 library/ipex/gradscaler.py create mode 100644 library/ipex/hijacks.py create mode 100644 requirements_linux_ipex.txt diff --git a/XTI_hijack.py b/XTI_hijack.py index 36b5d3f2b..ec0849455 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,4 +1,11 @@ import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index f89e897a8..f300d4688 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,6 +10,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 273c0dd86..0ea66cde2 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,6 +65,13 @@ import diffusers import numpy as np import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/gui.sh b/gui.sh index 7f66a3eff..253d8577f 100755 --- a/gui.sh +++ b/gui.sh @@ -59,12 +59,27 @@ if [[ "$OSTYPE" == "darwin"* ]]; then fi else if [ "$RUNPOD" = false ]; then - REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt" + if [[ "$@" == *"--use-ipex"* ]]; then + REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt" + else + REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt" + fi else REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_runpod.txt" fi fi +#Set OneAPI environmet if it's not set by the user +if [[ "$@" == *"--use-ipex"* ]] && [ ! 
-x "$(command -v sycl-ls)" ] +then + echo "Setting OneAPI environment" + if [[ -z "$ONEAPI_ROOT" ]] + then + ONEAPI_ROOT=/opt/intel/oneapi + fi + source $ONEAPI_ROOT/setvars.sh +fi + # Validate the requirements and run the script if successful if python "$SCRIPT_DIR/setup/validate_requirements.py" -r "$REQUIREMENTS_FILE"; then python "$SCRIPT_DIR/kohya_gui.py" "$@" diff --git a/kohya_gui.py b/kohya_gui.py index da5a04a3d..da25f2b8f 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -133,6 +133,10 @@ def UI(**kwargs): '--language', type=str, default=None, help='Set custom language' ) + parser.add_argument( + '--use-ipex', action='store_true', help='Use IPEX environment' + ) + args = parser.parse_args() UI( diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py new file mode 100644 index 000000000..3d5f8da16 --- /dev/null +++ b/library/ipex/__init__.py @@ -0,0 +1,164 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from .hijacks import ipex_hijacks + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +def ipex_init(): # pylint: disable=too-many-statements + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + 
torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: + torch.cuda.amp.GradScaler = 
torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 + + #Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda: True + torch.cuda.is_fp16_supported = lambda: True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.cuda.ipc_collect = lambda: None + torch.cuda.utilization = lambda: 0 + + ipex_hijacks() + try: + from .diffusers import ipex_diffusers # pylint: disable=import-outside-toplevel, import-error + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py new file mode 100644 index 000000000..18563b061 --- /dev/null +++ b/library/ipex/diffusers.py @@ -0,0 +1,262 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import torch.nn.functional as F # pylint: disable=ungrouped-imports +import diffusers #0.20.2 # pylint: disable=import-error + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +Attention = diffusers.models.attention_processor.Attention + +class SlicedAttnProcessor: # pylint: disable=too-few-public-methods + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. 
+ """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, shape_three = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if block_size >= 4000: + do_split_2 = True + #Find something divisible with the query_tokens + while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, 
width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class AttnProcessor2_0: # pylint: disable=too-few-public-methods, invalid-name + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( # pylint: disable=too-many-arguments, too-many-statements, too-many-locals, too-many-branches + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the shape_one + while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the batch_size_attention + while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, 
device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + + query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = F.scaled_dot_product_attention( + query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False + ) + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + else: + query_slice = query[:, start_idx:end_idx] + key_slice = key[:, start_idx:end_idx] + attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None + + attn_slice = F.scaled_dot_product_attention( + query_slice, key_slice, value[:, start_idx:end_idx], + attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False + ) + hidden_states[:, start_idx:end_idx] = attn_slice + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +def ipex_diffusers(): + #ARC GPUs can't allocate more than 4GB to a single block: + diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0 diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py new file mode 100644 index 000000000..530212101 --- /dev/null +++ b/library/ipex/gradscaler.py @@ -0,0 +1,179 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. 
+ + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 
+ assert self._scale is not None + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + +def update(self, new_scale=None): + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device="cpu", non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + to_device = _scale.device + _scale = _scale.to("cpu") + _growth_tracker = _growth_tracker.to("cpu") + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + _scale = _scale.to(to_device) + _growth_tracker = _growth_tracker.to(to_device) + # To prepare for next iteration, clear the data collected from optimizers this iteration. 
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + +def gradscaler_init(): + torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ + torch.xpu.amp.GradScaler.unscale_ = unscale_ + torch.xpu.amp.GradScaler.update = update + return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py new file mode 100644 index 000000000..f527e8836 --- /dev/null +++ b/library/ipex/hijacks.py @@ -0,0 +1,199 @@ +import contextlib +import importlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return + +class CondFunc: # pylint: disable=missing-class-docstring + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) + +_utils = torch.utils.data._utils +def _shutdown_workers(self): + if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None: + return + if hasattr(self, "_shutdown") and not self._shutdown: + self._shutdown = True + try: + if hasattr(self, '_pin_memory_thread'): + self._pin_memory_thread_done_event.set() + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: # pylint: disable=invalid-name + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: # pylint: disable=invalid-name + q.cancel_join_thread() + q.close() + finally: + if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: # pylint: disable=invalid-name + if w.is_alive(): + w.terminate() + +class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods + def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument + if isinstance(device_ids, list) and len(device_ids) > 1: + print("IPEX backend doesn't support DataParallel on multiple XPU devices") + return module.to("xpu") + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return contextlib.nullcontext() + +def check_device(device): + return 
bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) + +def return_xpu(device): + return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + +def ipex_no_cuda(orig_func, *args, **kwargs): + torch.cuda.is_available = lambda: False + orig_func(*args, **kwargs) + torch.cuda.is_available = torch.xpu.is_available + +original_autocast = torch.autocast +def ipex_autocast(*args, **kwargs): + if len(args) > 1 and args[0] == "cuda": + return original_autocast("xpu", *args[1:], **kwargs) + else: + return original_autocast(*args, **kwargs) + +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +original_interpolate = torch.nn.functional.interpolate +def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments + if antialias or align_corners is not None: + return_device = tensor.device + return_dtype = tensor.dtype + return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) + else: + return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + +original_linalg_solve = torch.linalg.solve +def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name + if A.device != torch.device("cpu") or B.device != torch.device("cpu"): + return_device = A.device + return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + else: + return original_linalg_solve(A, B, *args, **kwargs) + +def ipex_hijacks(): + CondFunc('torch.Tensor.to', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.Tensor.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.empty', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), + lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) + CondFunc('torch.randn', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.ones', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, 
device=None, **kwargs: check_device(device)) + CondFunc('torch.zeros', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.linspace', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.batch_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + CondFunc('torch.instance_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + + #Functions with dtype errors: + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs), + lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype) + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + if not torch.xpu.has_fp64_dtype(): + CondFunc('torch.from_numpy', + lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), + lambda orig_func, ndarray: ndarray.dtype == float) + + #Broken functions when torch.cuda.is_available is True: + CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', + lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), + lambda orig_func, *args, **kwargs: True) + + #Functions that make compile mad with CondFunc: + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.nn.DataParallel = DummyDataParallel + torch.autocast = 
ipex_autocast + torch.cat = torch_cat + torch.linalg.solve = linalg_solve + torch.nn.functional.interpolate = interpolate + torch.backends.cuda.sdp_kernel = return_null_context diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt new file mode 100644 index 000000000..e0fdb1e55 --- /dev/null +++ b/requirements_linux_ipex.txt @@ -0,0 +1,3 @@ +torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu # no_verify leave this to specify not checking this a verification stage +tensorboard==2.12.3 tensorflow==2.12.0 intel-extension-for-tensorflow[gpu] +-r requirements.txt diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c506ad3fc..f0d308511 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -17,6 +17,13 @@ import diffusers import numpy as np import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 5c8a0bd89..ff865629e 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,6 +9,13 @@ from einops import repeat import numpy as np import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index 195467b00..6b255d679 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -10,6 +10,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 09cf16438..f8169bdbf 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -11,6 +11,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_control_net_lllite_alt.py b/sdxl_train_control_net_lllite_alt.py index 757194a10..61ebfb581 100644 --- a/sdxl_train_control_net_lllite_alt.py +++ b/sdxl_train_control_net_lllite_alt.py @@ -14,6 +14,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 8d3a81c3a..2de57c0ac 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,5 +1,12 @@ import argparse import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network 
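
(The same guarded initializer is prepended to every entry-point script touched by this patch series; as a standalone sketch of the pattern shown in the hunks above and below — the names library.ipex and ipex_init come from this patch itself — it reads:)

    import torch
    try:
        import intel_extension_for_pytorch as ipex  # only importable on IPEX installs
        if torch.xpu.is_available():
            from library.ipex import ipex_init  # package added by this patch
            ipex_init()  # remaps the torch.cuda.* API onto torch.xpu.*
    except Exception:
        pass  # no IPEX or no XPU device: fall back to stock torch behaviour

(The broad except keeps CUDA-only and CPU-only setups unaffected when intel_extension_for_pytorch is not installed.)
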
diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 123ca35a1..f5cca17b2 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,6 +3,13 @@ import regex import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/setup.sh b/setup.sh index 3f521f37c..5cb6a4e83 100755 --- a/setup.sh +++ b/setup.sh @@ -27,6 +27,7 @@ Options: -s, --skip-space-check Skip the 10Gb minimum storage space check. -u, --no-gui Skips launching the GUI. -v, --verbose Increase verbosity levels up to 3. + --use-ipex Use IPEX with Intel ARC GPUs. EOF } @@ -87,6 +88,7 @@ MAXVERBOSITY=6 DIR="" PARENT_DIR="" VENV_DIR="" +USE_IPEX=false # Function to get the distro name get_distro_name() { @@ -203,6 +205,8 @@ install_python_dependencies() { "lin"*) if [ "$RUNPOD" = true ]; then python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt + elif [ "$USE_IPEX" = true ]; then + python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt else python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt fi @@ -318,6 +322,7 @@ while getopts ":vb:d:g:inprus-:" opt; do s | skip-space-check) SKIP_SPACE_CHECK=true ;; u | no-gui) SKIP_GUI=true ;; v) ((VERBOSITY = VERBOSITY + 1)) ;; + use-ipex) USE_IPEX=true ;; h) display_help && exit 0 ;; *) display_help && exit 0 ;; esac diff --git a/setup/setup_common.py b/setup/setup_common.py index 8d94ca9f3..d9cacf35f 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -195,12 +195,24 @@ def check_torch(): '/opt/rocm/bin/rocminfo' ): log.info('AMD toolkit detected') + elif (shutil.which('sycl-ls') is not None + or os.environ.get('ONEAPI_ROOT') is not None + or os.path.exists('/opt/intel/oneapi')): + log.info('Intel OneAPI toolkit detected') else: log.info('Using CPU-only Torch') try: import torch - + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() + os.environ.setdefault('NEOReadDebugKeys', '1') + os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') + except Exception: + pass log.info(f'Torch {torch.__version__}') # Check if CUDA is available @@ -208,10 +220,14 @@ def check_torch(): log.warning('Torch reports CUDA not available') else: if torch.version.cuda: - # Log nVidia CUDA and cuDNN versions - log.info( - f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + # Log Intel IPEX OneAPI version + log.info(f'Torch backend: Intel IPEX OneAPI {ipex.__version__}') + else: + # Log nVidia CUDA and cuDNN versions + log.info( + f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' + ) elif torch.version.hip: # Log AMD ROCm HIP version log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') @@ -222,9 +238,14 @@ def check_torch(): for device in [ torch.cuda.device(i) for i in range(torch.cuda.device_count()) ]: - log.info( - f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch 
{torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + log.info( + f'Torch detected GPU: Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + ) + else: + log.info( + f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + ) return int(torch.__version__[0]) except Exception as e: # log.warning(f'Could not load torch: {e}') diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index 73e94eb65..664ccfd9b 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -35,12 +35,22 @@ def check_torch(): '/opt/rocm/bin/rocminfo' ): log.info('AMD toolkit detected') + elif (shutil.which('sycl-ls') is not None + or os.environ.get('ONEAPI_ROOT') is not None + or os.path.exists('/opt/intel/oneapi')): + log.info('Intel OneAPI toolkit detected') else: log.info('Using CPU-only Torch') try: import torch - + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() + except Exception: + pass log.info(f'Torch {torch.__version__}') # Check if CUDA is available @@ -48,10 +58,14 @@ def check_torch(): log.warning('Torch reports CUDA not available') else: if torch.version.cuda: - # Log nVidia CUDA and cuDNN versions - log.info( - f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + # Log Intel IPEX OneAPI version + log.info(f'Torch backend: Intel IPEX {ipex.__version__}') + else: + # Log nVidia CUDA and cuDNN versions + log.info( + f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}' + ) elif torch.version.hip: # Log AMD ROCm HIP version log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') @@ -62,9 +76,14 @@ def check_torch(): for device in [ torch.cuda.device(i) for i in range(torch.cuda.device_count()) ]: - log.info( - f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + log.info( + f'Torch detected GPU: Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + ) + else: + log.info( + f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + ) return int(torch.__version__[0]) except Exception as e: log.error(f'Could not load torch: {e}') diff --git a/train_controlnet.py b/train_controlnet.py index 988304f62..42da44125 100644 --- a/train_controlnet.py 
+++ b/train_controlnet.py @@ -11,6 +11,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 6dde7e9bf..feb147787 100644 --- a/train_db.py +++ b/train_db.py @@ -11,6 +11,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index f752607e9..200fc2cfe 100644 --- a/train_network.py +++ b/train_network.py @@ -12,6 +12,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b65d524cf..1c7b7fcb2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -7,6 +7,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 79c64cbeb..2c5673be1 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,6 +8,13 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler From f8fa3f3de6ed12049fa37e489022791c53f98621 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Fri, 8 Sep 2023 04:34:19 +0300 Subject: [PATCH 02/10] Fix SDP --- gui.sh | 13 +- library/ipex/__init__.py | 296 ++++++++++++++++++------------------ library/ipex/attention.py | 128 ++++++++++++++++ library/ipex/diffusers.py | 146 +----------------- library/ipex/hijacks.py | 11 +- requirements_linux_ipex.txt | 2 +- 6 files changed, 294 insertions(+), 302 deletions(-) create mode 100644 library/ipex/attention.py diff --git a/gui.sh b/gui.sh index 253d8577f..a7980ccb9 100755 --- a/gui.sh +++ b/gui.sh @@ -70,14 +70,19 @@ else fi #Set OneAPI environmet if it's not set by the user -if [[ "$@" == *"--use-ipex"* ]] && [ ! -x "$(command -v sycl-ls)" ] +if [[ "$@" == *"--use-ipex"* ]] then echo "Setting OneAPI environment" - if [[ -z "$ONEAPI_ROOT" ]] + if [ ! 
-x "$(command -v sycl-ls)" ] then - ONEAPI_ROOT=/opt/intel/oneapi + if [[ -z "$ONEAPI_ROOT" ]] + then + ONEAPI_ROOT=/opt/intel/oneapi + fi + source $ONEAPI_ROOT/setvars.sh fi - source $ONEAPI_ROOT/setvars.sh + export NEOReadDebugKeys=1 + export ClDeviceGlobalMemSizeAvailablePercent=100 fi # Validate the requirements and run the script if successful diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 3d5f8da16..9ec69012f 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -4,161 +4,167 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks +from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements - #Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = 
torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + try: + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = 
torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = 
torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed - #AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda: True - torch.cuda.is_fp16_supported = lambda: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - 
torch.cuda.ipc_collect = lambda: None - torch.cuda.utilization = lambda: 0 + #Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - try: - from .diffusers import ipex_diffusers # pylint: disable=import-outside-toplevel, import-error - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + attention_init() + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + except Exception as e: + return False, e + return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py new file mode 100644 index 000000000..d7335bfaf --- /dev/null +++ b/library/ipex/attention.py @@ -0,0 +1,128 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +original_torch_bmm = torch.bmm +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] + block_multiply = 2.4 if input.dtype == torch.float32 else 1.2 + block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the input_tokens + while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_2_slice_size = input_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the input_tokens + while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, 
start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the shape_one + while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the batch_size_attention + while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx, start_idx_2:end_idx_2], + key[:, start_idx:end_idx, start_idx_2:end_idx_2], + value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx], + key[:, start_idx:end_idx], + value[:, start_idx:end_idx], + attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + return original_scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal + ) + return hidden_states + +def attention_init(): + #ARC GPUs can't allocate more than 4GB to a single block: + torch.bmm = torch_bmm + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 18563b061..3435abe14 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,12 +1,9 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import torch.nn.functional as F # pylint: disable=ungrouped-imports import diffusers #0.20.2 # pylint: disable=import-error # pylint: 
disable=protected-access, missing-function-docstring, line-too-long -Attention = diffusers.models.attention_processor.Attention - class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. @@ -20,7 +17,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: diffusers.models.attention_processor.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches residual = hidden_states input_ndim = hidden_states.ndim @@ -116,147 +113,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class AttnProcessor2_0: # pylint: disable=too-few-public-methods, invalid-name - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( # pylint: disable=too-many-arguments, too-many-statements, too-many-locals, too-many-branches - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 - block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB - split_slice_size = batch_size_attention - if block_size >= 4000: - do_split = True - #Find something divisible with the shape_one - while 
((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB - split_2_slice_size = query_tokens - if split_block_size >= 4000: - do_split_2 = True - #Find something divisible with the batch_size_attention - while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - if do_split: - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - - query_slice = query[:, start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[:, start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = F.scaled_dot_product_attention( - query_slice, key_slice, value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False - ) - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - else: - query_slice = query[:, start_idx:end_idx] - key_slice = key[:, start_idx:end_idx] - attn_mask_slice = attention_mask[:, start_idx:end_idx] if attention_mask is not None else None - - attn_slice = F.scaled_dot_product_attention( - query_slice, key_slice, value[:, start_idx:end_idx], - attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False - ) - hidden_states[:, start_idx:end_idx] = attn_slice - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor2_0 = AttnProcessor2_0 diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index f527e8836..78d7e034b 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -34,7 +34,7 @@ def __call__(self, *args, **kwargs): _utils = torch.utils.data._utils def _shutdown_workers(self): - if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None: + if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: return if hasattr(self, "_shutdown") and 
not self._shutdown: self._shutdown = True @@ -50,13 +50,13 @@ def _shutdown_workers(self): if self._persistent_workers or self._workers_status[worker_id]: self._mark_worker_as_unavailable(worker_id, shutdown=True) for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) for q in self._index_queues: # pylint: disable=invalid-name q.cancel_join_thread() q.close() finally: if self._worker_pids_set: - _utils.signal_handling._remove_worker_pids(id(self)) + torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) self._worker_pids_set = False for w in self._workers: # pylint: disable=invalid-name if w.is_alive(): @@ -84,7 +84,7 @@ def ipex_no_cuda(orig_func, *args, **kwargs): original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): - if len(args) > 1 and args[0] == "cuda": + if len(args) > 0 and args[0] == "cuda": return original_autocast("xpu", *args[1:], **kwargs) else: return original_autocast(*args, **kwargs) @@ -169,9 +169,6 @@ def ipex_hijacks(): CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.bmm', - lambda orig_func, input, mat2, *args, **kwargs: orig_func(input, mat2.to(input.dtype), *args, **kwargs), - lambda orig_func, input, mat2, *args, **kwargs: input.dtype != mat2.dtype) CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt index e0fdb1e55..28c0bc6b1 100644 --- a/requirements_linux_ipex.txt +++ b/requirements_linux_ipex.txt @@ -1,3 +1,3 @@ -torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu # no_verify leave this to specify not checking this a verification stage +torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu # no_verify leave this to specify not checking this a verification stage tensorboard==2.12.3 tensorflow==2.12.0 intel-extension-for-tensorflow[gpu] -r requirements.txt From f2b6c4603e08f1708c6369b5d0880a5cb9b0f24b Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 11 Sep 2023 23:39:24 +0300 Subject: [PATCH 03/10] Fix finetune --- library/model_util.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/library/model_util.py b/library/model_util.py index 860c170b2..00a3c0495 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,6 +4,13 @@ import math import os import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel From 15186d5514485b6e2f34b4e7d413533e4e661308 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 14 Sep 2023 15:54:55 +0300 Subject: [PATCH 04/10] fix diffusers 0.21 lazy import --- library/ipex/diffusers.py | 5 +++-- library/ipex/hijacks.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/library/ipex/diffusers.py 
b/library/ipex/diffusers.py index 3435abe14..4c39896ed 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,6 +1,7 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.20.2 # pylint: disable=import-error +import diffusers #0.21.1 # pylint: disable=import-error +from diffusers.models.attention_processor import Attention # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -17,7 +18,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: diffusers.models.attention_processor.Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches residual = hidden_states input_ndim = hidden_states.ndim diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 78d7e034b..77ed5419a 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -75,7 +75,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" def ipex_no_cuda(orig_func, *args, **kwargs): torch.cuda.is_available = lambda: False From c1a454f95c2200db3d27dbf4b90bf86f8124b6af Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sun, 17 Sep 2023 16:20:11 +0300 Subject: [PATCH 05/10] Update SDPA slice rate --- library/ipex/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/ipex/attention.py b/library/ipex/attention.py index d7335bfaf..fc4ab6e26 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -65,7 +65,7 @@ def torch_bmm(input, mat2, *, out=None): def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): #ARC GPUs can't allocate more than 4GB to a single block, Slice it: shape_one, batch_size_attention, query_tokens, shape_four = query.shape - block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_multiply = 3.6 if query.dtype == torch.float32 else 1.8 block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB split_slice_size = batch_size_attention if block_size >= 4000: From 97d60da67b991eb130e06eb37d605b24f038a63d Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 18 Sep 2023 15:37:24 +0300 Subject: [PATCH 06/10] Fix SDPA --- library/ipex/attention.py | 52 ++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/library/ipex/attention.py b/library/ipex/attention.py index fc4ab6e26..e38689f21 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -64,7 +64,13 @@ def torch_bmm(input, mat2, *, out=None): original_scaled_dot_product_attention = 
torch.nn.functional.scaled_dot_product_attention def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape + if len(query.shape) == 3: + batch_size_attention, query_tokens, shape_four = query.shape + shape_one = 1 + no_shape_one = True + else: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + no_shape_one = False block_multiply = 3.6 if query.dtype == torch.float32 else 1.8 block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB split_slice_size = batch_size_attention @@ -101,21 +107,39 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx, start_idx_2:end_idx_2], - key[:, start_idx:end_idx, start_idx_2:end_idx_2], - value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + if no_shape_one: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2], + key[start_idx:end_idx, start_idx_2:end_idx_2], + value[start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx, start_idx_2:end_idx_2], + key[:, start_idx:end_idx, start_idx_2:end_idx_2], + value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + if no_shape_one: + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx], + key[:, start_idx:end_idx], + value[:, start_idx:end_idx], + attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal ) - else: - hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx], - key[:, start_idx:end_idx], - value[:, start_idx:end_idx], - attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) else: return original_scaled_dot_product_attention( query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal From 912c9d0456459f7510d75b07fbecdbd61bddaa64 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sun, 24 Sep 2023 13:50:12 +0300 Subject: [PATCH 07/10] Remove non GUI changes --- XTI_hijack.py | 7 - fine_tune.py | 7 - gen_img_diffusers.py | 7 - library/ipex/__init__.py | 170 
----------------------- library/ipex/attention.py | 152 --------------------- library/ipex/diffusers.py | 119 ---------------- library/ipex/gradscaler.py | 179 ------------------------ library/ipex/hijacks.py | 196 --------------------------- library/model_util.py | 7 - sdxl_gen_img.py | 7 - sdxl_minimal_inference.py | 7 - sdxl_train.py | 7 - sdxl_train_control_net_lllite.py | 7 - sdxl_train_control_net_lllite_alt.py | 7 - sdxl_train_network.py | 7 - sdxl_train_textual_inversion.py | 7 - train_controlnet.py | 7 - train_db.py | 7 - train_network.py | 7 - train_textual_inversion.py | 7 - train_textual_inversion_XTI.py | 7 - 21 files changed, 928 deletions(-) delete mode 100644 library/ipex/__init__.py delete mode 100644 library/ipex/attention.py delete mode 100644 library/ipex/diffusers.py delete mode 100644 library/ipex/gradscaler.py delete mode 100644 library/ipex/hijacks.py diff --git a/XTI_hijack.py b/XTI_hijack.py index ec0849455..36b5d3f2b 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,4 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index f300d4688..f89e897a8 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,13 +10,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 0ea66cde2..273c0dd86 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,13 +65,6 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py deleted file mode 100644 index 9ec69012f..000000000 --- a/library/ipex/__init__.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import sys -import contextlib -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from .hijacks import ipex_hijacks -from .attention import attention_init - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -def ipex_init(): # pylint: disable=too-many-statements - try: - #Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - 
torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - - #Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = 
torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - - #RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed - - #AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - - #C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 - - #Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 - - ipex_hijacks() - attention_init() - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass - except Exception as e: - return False, e - return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py deleted file mode 100644 index e38689f21..000000000 --- a/library/ipex/attention.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - - 
#ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = 2.4 if input.dtype == torch.float32 else 1.2 - block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB - split_slice_size = batch_size_attention - if block_size >= 4000: - do_split = True - #Find something divisible with the input_tokens - while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB - split_2_slice_size = input_tokens - if split_block_size >= 4000: - do_split_2 = True - #Find something divisible with the input_tokens - while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - if do_split: - hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) - else: - hidden_states[start_idx:end_idx] = original_torch_bmm( - input[start_idx:end_idx], - mat2[start_idx:end_idx], - out=out - ) - else: - return original_torch_bmm(input, mat2, out=out) - return hidden_states - -original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_four = query.shape - shape_one = 1 - no_shape_one = True - else: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - no_shape_one = False - block_multiply = 3.6 if query.dtype == torch.float32 else 1.8 - block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB - split_slice_size = batch_size_attention - if block_size >= 4000: - do_split = True - #Find something divisible with the shape_one - while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB - split_2_slice_size = query_tokens - if split_block_size >= 4000: - do_split_2 = True - #Find something divisible with the batch_size_attention - while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - if 
do_split: - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if no_shape_one: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2], - key[start_idx:end_idx, start_idx_2:end_idx_2], - value[start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx, start_idx_2:end_idx_2], - key[:, start_idx:end_idx, start_idx_2:end_idx_2], - value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - if no_shape_one: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx], - key[:, start_idx:end_idx], - value[:, start_idx:end_idx], - attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - return original_scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal - ) - return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py deleted file mode 100644 index 4c39896ed..000000000 --- a/library/ipex/diffusers.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.21.1 # pylint: disable=import-error -from diffusers.models.attention_processor import Attention - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -class SlicedAttnProcessor: # pylint: disable=too-few-public-methods - r""" - Processor for implementing sliced attention. - - Args: - slice_size (`int`, *optional*): - The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. 
- """ - - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, shape_three = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 - block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB - split_2_slice_size = query_tokens - if block_size >= 4000: - do_split_2 = True - #Find something divisible with the query_tokens - while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, 
width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -def ipex_diffusers(): - #ARC GPUs can't allocate more than 4GB to a single block: - diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py deleted file mode 100644 index 530212101..000000000 --- a/library/ipex/gradscaler.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections import defaultdict -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -OptState = ipex.cpu.autocast._grad_scaler.OptState -_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator -_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state - -def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument - per_device_inv_scale = _MultiDeviceReplicator(inv_scale) - per_device_found_inf = _MultiDeviceReplicator(found_inf) - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. - - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] - # sync grad to master weight - if hasattr(optimizer, "sync_grad"): - optimizer.sync_grad() - with torch.no_grad(): - for group in optimizer.param_groups: - for param in group["params"]: - if param.grad is None: - continue - if (not allow_fp16) and param.grad.dtype == torch.float16: - raise ValueError("Attempting to unscale FP16 gradients.") - if param.grad.is_sparse: - # is_coalesced() == False means the sparse grad has values with duplicate indices. - # coalesce() deduplicates indices and adds all values that have the same index. - # For scaled fp16 values, there's a good chance coalescing will cause overflow, - # so we should check the coalesced _values(). - if param.grad.dtype is torch.float16: - param.grad = param.grad.coalesce() - to_unscale = param.grad._values() - else: - to_unscale = param.grad - - # -: is there a way to split by device and dtype without appending in the inner loop? - to_unscale = to_unscale.to("cpu") - per_device_and_dtype_grads[to_unscale.device][ - to_unscale.dtype - ].append(to_unscale) - - for _, per_dtype_grads in per_device_and_dtype_grads.items(): - for grads in per_dtype_grads.values(): - core._amp_foreach_non_finite_check_and_unscale_( - grads, - per_device_found_inf.get("cpu"), - per_device_inv_scale.get("cpu"), - ) - - return per_device_found_inf._per_device_tensors - -def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 
- Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise - raise RuntimeError( - "unscale_() has already been called on this optimizer since the last update()." - ) - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) - found_inf = torch.full( - (1,), 0.0, dtype=torch.float32, device=self._scale.device - ) - - optimizer_state["found_inf_per_device"] = self._unscale_grads_( - optimizer, inv_scale, found_inf, False - ) - optimizer_state["stage"] = OptState.UNSCALED - -def update(self, new_scale=None): - """ - Updates the scale factor. - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - Args: - new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 
- found_infs = [ - found_inf.to(device="cpu", non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - to_device = _scale.device - _scale = _scale.to("cpu") - _growth_tracker = _growth_tracker.to("cpu") - - core._amp_update_scale_( - _scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval, - ) - - _scale = _scale.to(to_device) - _growth_tracker = _growth_tracker.to(to_device) - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - -def gradscaler_init(): - torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ - torch.xpu.amp.GradScaler.unscale_ = unscale_ - torch.xpu.amp.GradScaler.update = update - return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py deleted file mode 100644 index 77ed5419a..000000000 --- a/library/ipex/hijacks.py +++ /dev/null @@ -1,196 +0,0 @@ -import contextlib -import importlib -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - -# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return - -class CondFunc: # pylint: disable=missing-class-docstring - def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split('.') - for i in range(len(func_path)-1, -1, -1): - try: - resolved_obj = importlib.import_module('.'.join(func_path[:i])) - break - except ImportError: - pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) - self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) - -_utils = torch.utils.data._utils -def _shutdown_workers(self): - if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: - return - if hasattr(self, "_shutdown") and not self._shutdown: - self._shutdown = True - try: - if hasattr(self, '_pin_memory_thread'): - self._pin_memory_thread_done_event.set() - self._worker_result_queue.put((None, None)) - self._pin_memory_thread.join() - self._worker_result_queue.cancel_join_thread() - self._worker_result_queue.close() - self._workers_done_event.set() - for worker_id in range(len(self._workers)): - if self._persistent_workers or self._workers_status[worker_id]: - self._mark_worker_as_unavailable(worker_id, shutdown=True) - for w in self._workers: # pylint: disable=invalid-name - 
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) - for q in self._index_queues: # pylint: disable=invalid-name - q.cancel_join_thread() - q.close() - finally: - if self._worker_pids_set: - torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) - self._worker_pids_set = False - for w in self._workers: # pylint: disable=invalid-name - if w.is_alive(): - w.terminate() - -class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods - def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument - if isinstance(device_ids, list) and len(device_ids) > 1: - print("IPEX backend doesn't support DataParallel on multiple XPU devices") - return module.to("xpu") - -def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() - -def check_device(device): - return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) - -def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" - -def ipex_no_cuda(orig_func, *args, **kwargs): - torch.cuda.is_available = lambda: False - orig_func(*args, **kwargs) - torch.cuda.is_available = torch.xpu.is_available - -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) - else: - return original_autocast(*args, **kwargs) - -original_torch_cat = torch.cat -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) - -original_interpolate = torch.nn.functional.interpolate -def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None: - return_device = tensor.device - return_dtype = tensor.dtype - return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) - else: - return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) - -original_linalg_solve = torch.linalg.solve -def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name - if A.device != torch.device("cpu") or B.device != torch.device("cpu"): - return_device = A.device - return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) - else: - return original_linalg_solve(A, B, *args, **kwargs) - -def ipex_hijacks(): - CondFunc('torch.Tensor.to', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.Tensor.cuda', - lambda orig_func, self, device=None, 
*args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.empty', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), - lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) - CondFunc('torch.randn', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.ones', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.zeros', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.linspace', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - - CondFunc('torch.batch_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - CondFunc('torch.instance_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - - #Functions with dtype errors: - CondFunc('torch.nn.modules.GroupNorm.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.linear.Linear.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.conv.Conv2d.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.functional.layer_norm', - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - weight is not None and input.dtype != 
weight.data.dtype) - - #Diffusers Float64 (ARC GPUs doesn't support double or Float64): - if not torch.xpu.has_fp64_dtype(): - CondFunc('torch.from_numpy', - lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), - lambda orig_func, ndarray: ndarray.dtype == float) - - #Broken functions when torch.cuda.is_available is True: - CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', - lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), - lambda orig_func, *args, **kwargs: True) - - #Functions that make compile mad with CondFunc: - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers - torch.nn.DataParallel = DummyDataParallel - torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve - torch.nn.functional.interpolate = interpolate - torch.backends.cuda.sdp_kernel = return_null_context diff --git a/library/model_util.py b/library/model_util.py index 00a3c0495..860c170b2 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,13 +4,6 @@ import math import os import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index f0d308511..c506ad3fc 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -17,13 +17,6 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index ff865629e..5c8a0bd89 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,6 @@ from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index 6b255d679..195467b00 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -10,13 +10,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f8169bdbf..09cf16438 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -11,13 +11,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_control_net_lllite_alt.py b/sdxl_train_control_net_lllite_alt.py index 61ebfb581..757194a10 100644 --- 
a/sdxl_train_control_net_lllite_alt.py +++ b/sdxl_train_control_net_lllite_alt.py @@ -14,13 +14,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 2de57c0ac..8d3a81c3a 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,12 +1,5 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f5cca17b2..123ca35a1 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,6 @@ import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/train_controlnet.py b/train_controlnet.py index 42da44125..988304f62 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -11,13 +11,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index feb147787..6dde7e9bf 100644 --- a/train_db.py +++ b/train_db.py @@ -11,13 +11,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index 200fc2cfe..f752607e9 100644 --- a/train_network.py +++ b/train_network.py @@ -12,13 +12,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 1c7b7fcb2..b65d524cf 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -7,13 +7,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2c5673be1..79c64cbeb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,6 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except 
Exception: - pass from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler From b07f780a262e6a9773e36d1d10af8a3d3933b4b6 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 2 Oct 2023 19:29:20 +0300 Subject: [PATCH 08/10] Fix typo --- setup/setup_common.py | 2 +- setup/validate_requirements.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup/setup_common.py b/setup/setup_common.py index d9cacf35f..9a0ecdef3 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -240,7 +240,7 @@ def check_torch(): ]: if hasattr(torch, "xpu") and torch.xpu.is_available(): log.info( - f'Torch detected GPU: Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' ) else: log.info( diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index 664ccfd9b..9d17ffda2 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -78,7 +78,7 @@ def check_torch(): ]: if hasattr(torch, "xpu") and torch.xpu.is_available(): log.info( - f'Torch detected GPU: Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' ) else: log.info( From fd92e1cefdf976da5273b6c863c74eff4c433ca6 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 2 Oct 2023 19:34:09 +0300 Subject: [PATCH 09/10] Fix typo warning --- gui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui.sh b/gui.sh index a7980ccb9..debca42fa 100755 --- a/gui.sh +++ b/gui.sh @@ -69,7 +69,7 @@ else fi fi -#Set OneAPI environmet if it's not set by the user +#Set OneAPI if it's not set by the user if [[ "$@" == *"--use-ipex"* ]] then echo "Setting OneAPI environment" From 76473d0793108b983b072808402375c2ef5cb767 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 2 Oct 2023 19:51:14 +0300 Subject: [PATCH 10/10] Update IPEX index --- requirements_linux_ipex.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt index 28c0bc6b1..61d8a75f4 100644 --- a/requirements_linux_ipex.txt +++ b/requirements_linux_ipex.txt @@ -1,3 +1,3 @@ -torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu # no_verify leave this to specify not checking this a verification stage +torch==2.0.1a0+cxx11.abi torchvision==0.15.2a0+cxx11.abi intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ tensorboard==2.12.3 tensorflow==2.12.0 intel-extension-for-tensorflow[gpu] -r requirements.txt
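For reference, the activation pattern these patches revolve around can be sketched in a few lines. This is an illustrative snippet, not part of the patch series itself; it only uses names that appear in the diffs above (library.ipex.ipex_init, torch.xpu, intel_extension_for_pytorch), and it assumes the library/ipex module layout introduced in PATCH 01/10 is present on the import path.

    import torch

    # Minimal sketch of the guarded IPEX activation used by the entry-point
    # scripts in this series. Both the import and the xpu check are wrapped so
    # that machines without intel_extension_for_pytorch fall through silently.
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401  # importing registers the xpu backend
        if torch.xpu.is_available():
            from library.ipex import ipex_init
            ipex_init()  # remaps the torch.cuda namespace and related helpers onto torch.xpu
    except Exception:
        pass

    # After ipex_init(), existing CUDA-oriented call sites keep working, e.g.
    # tensor.to("cuda") is redirected to an xpu device by the CondFunc hijacks,
    # so training code does not need per-backend branches.

Because the guard degrades to a no-op when the extension is missing, the same scripts continue to run unchanged on CUDA or CPU-only machines.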