diff --git a/flux_train.py b/flux_train.py
index 79c44d7b4..afddc897f 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -17,12 +17,14 @@ import os
 from multiprocessing import Value
 import time
-from typing import List
+from typing import List, Optional, Tuple, Union

 import toml
 from tqdm import tqdm

 import torch
+import torch.nn as nn
+from library import utils
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()
@@ -466,45 +468,28 @@ def train(args):

     # memory efficient block swapping

-    def get_block_unit(dbl_blocks, sgl_blocks, index: int):
-        if index < len(dbl_blocks):
-            return (dbl_blocks[index],)
-        else:
-            index -= len(dbl_blocks)
-            index *= 2
-            return (sgl_blocks[index], sgl_blocks[index + 1])
-
-    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
-        def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
-            # print(f"Backward: Move block {bidx_to_cpu} to CPU")
-            for block in blocks_to_cpu:
-                block = block.to("cpu", non_blocking=True)
-            torch.cuda.empty_cache()
-
-            # print(f"Backward: Move block {bidx_to_cuda} to CUDA")
-            for block in blocks_to_cuda:
-                block = block.to(dvc, non_blocking=True)
-
-            torch.cuda.synchronize()
-            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
-            return bidx_to_cpu, bidx_to_cuda
-
-        blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
-        blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)
-
-        futures[block_idx_to_cuda] = thread_pool.submit(
-            move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
-        )
+    def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
+        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+            # start_time = time.perf_counter()
+            # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
+            utils.swap_weight_devices(block_to_cpu, block_to_cuda)
+            # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
+            return bidx_to_cpu, bidx_to_cuda  # , event
+
+        block_to_cpu = blocks[block_idx_to_cpu]
+        block_to_cuda = blocks[block_idx_to_cuda]
+
+        futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)

-    def wait_blocks_move(block_idx, futures):
-        if block_idx not in futures:
+    def wait_blocks_move(block_id, futures):
+        if block_id not in futures:
             return
-        # print(f"Backward: Wait for block {block_idx}")
+        # print(f"Backward: Wait for block {block_id}")
         # start_time = time.perf_counter()
-        future = futures.pop(block_idx)
-        future.result()
-        # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
-        # torch.cuda.synchronize()
+        future = futures.pop(block_id)
+        _, bidx_to_cuda = future.result()
+        assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
+        # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
         # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")

     if args.fused_backward_pass:
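In the backward pass, each pending transfer is now keyed by a block_id tuple (is_double, block_idx), and the actual weight exchange is delegated to utils.swap_weight_devices on a worker thread so later parameters can keep backpropagating while the copy is in flight. Below is a minimal, CUDA-free sketch of the same submit/wait pattern; the time.sleep stands in for the real device transfer, and every name in it is illustrative only, not part of the patch:

    from concurrent.futures import ThreadPoolExecutor
    import time

    def submit_move(futures, pool, bidx_to_cpu, bidx_to_cuda, block_id):
        def move():
            time.sleep(0.01)  # stand-in for utils.swap_weight_devices(block_to_cpu, block_to_cuda)
            return bidx_to_cpu, bidx_to_cuda

        futures[block_id] = pool.submit(move)

    def wait_move(block_id, futures):
        if block_id not in futures:
            return  # nothing in flight for this block
        _, bidx_to_cuda = futures.pop(block_id).result()
        assert block_id[1] == bidx_to_cuda  # the awaited future brought in the expected block

    pool = ThreadPoolExecutor(max_workers=1)  # one worker: swaps overlap compute but stay ordered
    futures = {}
    submit_move(futures, pool, 0, 17, (True, 17))  # block_id = (is_double, block index)
    wait_move((True, 17), futures)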
@@ -513,11 +498,11 @@ def wait_blocks_move(block_idx, futures):

         library.adafactor_fused.patch_adafactor_fused(optimizer)

-        blocks_to_swap = args.blocks_to_swap
+        double_blocks_to_swap = args.blocks_to_swap // 2
+        single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
         num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
         num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
-        num_block_units = num_double_blocks + num_single_blocks // 2
-        handled_unit_indices = set()
+        handled_block_ids = set()

         n = 1  # only asynchronous purpose, no need to increase this number
         # n = 2
@@ -530,28 +515,37 @@ def wait_blocks_move(block_idx, futures):
                 if parameter.requires_grad:
                     grad_hook = None

-                    if blocks_to_swap:
+                    if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
                         is_double = param_name.startswith("double_blocks")
                         is_single = param_name.startswith("single_blocks")
-                        if is_double or is_single:
+                        if (is_double and double_blocks_to_swap > 0) or (is_single and single_blocks_to_swap > 0):
                             block_idx = int(param_name.split(".")[1])
-                            unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
-                            if unit_idx not in handled_unit_indices:
+                            block_id = (is_double, block_idx)  # double or single, block index
+                            if block_id not in handled_block_ids:
                                 # swap following (already backpropagated) block
-                                handled_unit_indices.add(unit_idx)
+                                handled_block_ids.add(block_id)

                                 # if n blocks were already backpropagated
-                                num_blocks_propagated = num_block_units - unit_idx - 1
+                                if is_double:
+                                    num_blocks = num_double_blocks
+                                    blocks_to_swap = double_blocks_to_swap
+                                else:
+                                    num_blocks = num_single_blocks
+                                    blocks_to_swap = single_blocks_to_swap
+
+                                # -1 for 0-based index, -1 because the current block is not fully backpropagated yet
+                                num_blocks_propagated = num_blocks - block_idx - 2
                                 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
-                                waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
+                                waiting = block_idx > 0 and block_idx <= blocks_to_swap
+
                                 if swapping or waiting:
-                                    block_idx_to_cpu = num_block_units - num_blocks_propagated
+                                    block_idx_to_cpu = num_blocks - num_blocks_propagated
                                     block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
-                                    block_idx_to_wait = unit_idx - 1
+                                    block_idx_to_wait = block_idx - 1

                                     # create swap hook
                                     def create_swap_grad_hook(
-                                        bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
+                                        is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
                                     ):
                                         def __grad_hook(tensor: torch.Tensor):
                                             if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -559,24 +553,25 @@ def __grad_hook(tensor: torch.Tensor):
                                             optimizer.step_param(tensor, param_group)
                                             tensor.grad = None

-                                            # print(f"Backward: {uidx}, {swpng}, {wtng}")
+                                            # print(
+                                            #     f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
+                                            # )
                                             if swpng:
                                                 submit_move_blocks(
                                                     futures,
                                                     thread_pool,
                                                     bidx_to_cpu,
                                                     bidx_to_cuda,
-                                                    flux.double_blocks,
-                                                    flux.single_blocks,
-                                                    accelerator.device,
+                                                    flux.double_blocks if is_dbl else flux.single_blocks,
+                                                    (is_dbl, bidx_to_cuda),  # wait for this block
                                                 )
                                             if wtng:
-                                                wait_blocks_move(bidx_to_wait, futures)
+                                                wait_blocks_move((is_dbl, bidx_to_wait), futures)

                                         return __grad_hook

                                     grad_hook = create_swap_grad_hook(
-                                        block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
+                                        is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
                                     )

                     if grad_hook is None:
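The backward-pass bookkeeping above drops the old "block unit" indexing (one double block, or a pair of single blocks, per unit) and tracks double and single blocks independently. args.blocks_to_swap is split so that half of the budget goes to double blocks and twice the remainder to the roughly half-sized single blocks, which keeps the offloaded parameter volume comparable between the two groups. A small sketch of that arithmetic, assuming the standard FLUX.1 depths of 19 double and 38 single blocks:

    num_double_blocks, num_single_blocks = 19, 38  # FLUX.1 dev/schnell depths

    def split_blocks_to_swap(blocks_to_swap: int):
        # same split as above: half the budget to double blocks,
        # twice the remainder to the (roughly half-sized) single blocks
        double_blocks_to_swap = blocks_to_swap // 2
        single_blocks_to_swap = (blocks_to_swap - double_blocks_to_swap) * 2
        return double_blocks_to_swap, single_blocks_to_swap

    for n in (4, 6, 9):
        d, s = split_blocks_to_swap(n)
        print(f"blocks_to_swap={n}: swap {d}/{num_double_blocks} double, {s}/{num_single_blocks} single blocks")
    # blocks_to_swap=6 -> 3 double blocks and 6 single blocks are offloaded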
diff --git a/library/flux_models.py b/library/flux_models.py
index 0bc1c02b9..48dea4fc9 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -7,8 +7,9 @@ import math
 import os
 import time
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

+from library import utils
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()
@@ -923,7 +924,8 @@ def __init__(self, params: FluxParams):

         self.blocks_to_swap = None
         self.thread_pool: Optional[ThreadPoolExecutor] = None
-        self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
+        self.num_double_blocks = len(self.double_blocks)
+        self.num_single_blocks = len(self.single_blocks)

     @property
     def device(self):
@@ -963,14 +965,17 @@ def disable_gradient_checkpointing(self):

     def enable_block_swap(self, num_blocks: int):
         self.blocks_to_swap = num_blocks
+        self.double_blocks_to_swap = num_blocks // 2
+        self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
+        print(
+            f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
+        )

         n = 1  # async block swap. 1 is enough
-        # n = 2
-        # n = max(1, os.cpu_count() // 2)
         self.thread_pool = ThreadPoolExecutor(max_workers=n)

     def move_to_device_except_swap_blocks(self, device: torch.device):
-        # assume model is on cpu
+        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
         if self.blocks_to_swap:
             save_double_blocks = self.double_blocks
             save_single_blocks = self.single_blocks
@@ -983,31 +988,55 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
             self.double_blocks = save_double_blocks
             self.single_blocks = save_single_blocks

-    def get_block_unit(self, index: int):
-        if index < len(self.double_blocks):
-            return (self.double_blocks[index],)
-        else:
-            index -= len(self.double_blocks)
-            index *= 2
-            return self.single_blocks[index], self.single_blocks[index + 1]
+    # def get_block_unit(self, index: int):
+    #     if index < len(self.double_blocks):
+    #         return (self.double_blocks[index],)
+    #     else:
+    #         index -= len(self.double_blocks)
+    #         index *= 2
+    #         return self.single_blocks[index], self.single_blocks[index + 1]

-    def get_unit_index(self, is_double: bool, index: int):
-        if is_double:
-            return index
-        else:
-            return len(self.double_blocks) + index // 2
+    # def get_unit_index(self, is_double: bool, index: int):
+    #     if is_double:
+    #         return index
+    #     else:
+    #         return len(self.double_blocks) + index // 2

     def prepare_block_swap_before_forward(self):
-        # make: first n blocks are on cuda, and last n blocks are on cpu
+        # # make: first n blocks are on cuda, and last n blocks are on cpu
+        # if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+        #     # raise ValueError("Block swap is not enabled.")
+        #     return
+        # for i in range(self.num_block_units - self.blocks_to_swap):
+        #     for b in self.get_block_unit(i):
+        #         b.to(self.device)
+        # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
+        #     for b in self.get_block_unit(i):
+        #         b.to("cpu")
+        # clean_memory_on_device(self.device)
+
+        # all blocks are on device, but some weights are on cpu
+        # make first n blocks weights on device, and last n blocks weights on cpu
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             # raise ValueError("Block swap is not enabled.")
             return
-        for i in range(self.num_block_units - self.blocks_to_swap):
-            for b in self.get_block_unit(i):
-                b.to(self.device)
-        for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
-            for b in self.get_block_unit(i):
-                b.to("cpu")
+
+        for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
+            b.to(self.device)
+            utils.weighs_to_device(b, self.device)  # make sure weights are on device
+        for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
+            b.to(self.device)  # move block to device first
+            utils.weighs_to_device(b, "cpu")  # make sure weights are on cpu
+        torch.cuda.synchronize()
+        clean_memory_on_device(self.device)
+
+        for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
+            b.to(self.device)
+            utils.weighs_to_device(b, self.device)  # make sure weights are on device
+        for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
+            b.to(self.device)  # move block to device first
+            utils.weighs_to_device(b, "cpu")  # make sure weights are on cpu
+        torch.cuda.synchronize()
         clean_memory_on_device(self.device)

     def forward(
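For context, this is roughly how a trainer would drive the methods above; the call sites live outside this diff, so the order below is inferred from the method semantics rather than quoted from the training script:

    def setup_block_swap(flux, device, blocks_to_swap: int):
        # hedged sketch; `flux` is the model patched above
        flux.enable_block_swap(blocks_to_swap)          # split into double/single counts, create the worker thread pool
        flux.move_to_device_except_swap_blocks(device)  # bulk move without temporarily duplicating the swappable blocks
        flux.prepare_block_swap_before_forward()        # leading blocks keep weights on the device, trailing ones on cpu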
@@ -1044,27 +1073,22 @@ def forward(
                 for block in self.single_blocks:
                     img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         else:
-            futures = {}
-
-            def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
-                def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
-                    # print(f"Moving {bidx_to_cpu} to cpu.")
-                    for block in blocks_to_cpu:
-                        block.to("cpu", non_blocking=True)
-                    torch.cuda.empty_cache()
+            # device = self.device

-                    # print(f"Moving {bidx_to_cuda} to cuda.")
-                    for block in blocks_to_cuda:
-                        block.to(self.device, non_blocking=True)
-
-                    torch.cuda.synchronize()
+            def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
+                def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+                    # start_time = time.perf_counter()
+                    # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
+                    utils.swap_weight_devices(block_to_cpu, block_to_cuda)
                     # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
-                    return block_idx_to_cpu, block_idx_to_cuda

-                blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
-                blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
+                    # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+                    return block_idx_to_cpu, block_idx_to_cuda  # , event
+
+                block_to_cpu = blocks[block_idx_to_cpu]
+                block_to_cuda = blocks[block_idx_to_cuda]
                 # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
-                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
+                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)

             def wait_for_blocks_move(block_idx, ftrs):
                 if block_idx not in ftrs:
@@ -1073,37 +1097,35 @@ def forward(
                 # start_time = time.perf_counter()
                 ftr = ftrs.pop(block_idx)
                 ftr.result()
-                # torch.cuda.synchronize()
-                # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+                # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")

+            double_futures = {}
             for block_idx, block in enumerate(self.double_blocks):
                 # print(f"Double block {block_idx}")
-                unit_idx = self.get_unit_index(is_double=True, index=block_idx)
-                wait_for_blocks_move(unit_idx, futures)
+                wait_for_blocks_move(block_idx, double_futures)
                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-                if unit_idx < self.blocks_to_swap:
-                    block_idx_to_cpu = unit_idx
-                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
-                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
-                    futures[block_idx_to_cuda] = future
+                if block_idx < self.double_blocks_to_swap:
+                    block_idx_to_cpu = block_idx
+                    block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
+                    future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
+                    double_futures[block_idx_to_cuda] = future

             img = torch.cat((txt, img), 1)

+            single_futures = {}
             for block_idx, block in enumerate(self.single_blocks):
                 # print(f"Single block {block_idx}")
-                unit_idx = self.get_unit_index(is_double=False, index=block_idx)
-                if block_idx % 2 == 0:
-                    wait_for_blocks_move(unit_idx, futures)
+                wait_for_blocks_move(block_idx, single_futures)
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-                if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
-                    block_idx_to_cpu = unit_idx
-                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
-                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
-                    futures[block_idx_to_cuda] = future
+                if block_idx < self.single_blocks_to_swap:
+                    block_idx_to_cpu = block_idx
+                    block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx
+                    future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
+                    single_futures[block_idx_to_cuda] = future

             img = img[:, txt.shape[1] :, ...]
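The forward path runs the same schedule independently for the double and the single stack (note the separate double_futures / single_futures dicts): the first num_blocks - blocks_to_swap blocks start with resident weights, and right after block i (for i < blocks_to_swap) finishes, its weights are sent to the CPU while the weights of block num_blocks - blocks_to_swap + i are fetched, to be waited on just before that block runs. A CUDA-free simulation of the schedule, pure bookkeeping with no real transfers:

    def simulate_forward_swap(num_blocks: int, blocks_to_swap: int):
        futures = {}  # block index whose weights are in flight -> pending transfer
        for block_idx in range(num_blocks):
            if block_idx in futures:
                futures.pop(block_idx)  # wait_for_blocks_move(): weights must be resident before running
            print(f"run block {block_idx}")
            if block_idx < blocks_to_swap:
                block_idx_to_cpu = block_idx
                block_idx_to_cuda = num_blocks - blocks_to_swap + block_idx
                futures[block_idx_to_cuda] = (block_idx_to_cpu, block_idx_to_cuda)
                print(f"  swap: block {block_idx_to_cpu} -> cpu, block {block_idx_to_cuda} -> cuda")

    simulate_forward_swap(num_blocks=19, blocks_to_swap=3)  # double blocks when blocks_to_swap=6
    simulate_forward_swap(num_blocks=38, blocks_to_swap=6)  # single blocks when blocks_to_swap=6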
diff --git a/library/utils.py b/library/utils.py
index ca0f904d2..aed510074 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -6,6 +6,7 @@ import struct

 import torch
+import torch.nn as nn
 from torchvision import transforms
 from diffusers import EulerAncestralDiscreteScheduler
 import diffusers.schedulers.scheduling_euler_ancestral_discrete
@@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False):

 # region PyTorch utils

+# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+#             # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
+#             # cpu_tensor = module_to_cuda.weight.data
+#             # cuda_tensor = module_to_cpu.weight.data
+#             # assert cuda_tensor.device.type == "cuda"
+#             # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
+#             # torch.cuda.current_stream().synchronize()
+#             # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
+#             # torch.cuda.current_stream().synchronize()
+#             # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
+#             # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
+#             cuda_tensor_view = module_to_cpu.weight.data
+#             cpu_tensor_view = module_to_cuda.weight.data
+#             module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
+#             module_to_cuda.weight.data = cuda_tensor_view
+#             module_to_cuda.weight.data.copy_(cpu_tensor_view)
+
+
+def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+    weight_swap_jobs = []
+    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
+
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        # cuda to cpu
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+            cuda_data_view.record_stream(stream)
+            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+
+        stream.synchronize()
+
+        # cpu to cuda
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+            module_to_cuda.weight.data = cuda_data_view
+
+    stream.synchronize()
+    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
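swap_weight_devices above exchanges the weight storage of two structurally identical blocks, one resident on CUDA and one on CPU, using a side stream so the copies can overlap work on the default stream. A minimal usage sketch, assuming a CUDA device is available; the Block class is only an illustration, not part of the library:

    import torch
    import torch.nn as nn

    from library.utils import swap_weight_devices

    class Block(nn.Module):  # stand-in for a FLUX double/single block
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(64, 64)
            self.linear2 = nn.Linear(64, 64)

    if torch.cuda.is_available():
        block_on_gpu = Block().to("cuda")
        block_on_cpu = Block()  # stays on cpu
        swap_weight_devices(layer_to_cpu=block_on_gpu, layer_to_cuda=block_on_cpu)
        print(block_on_gpu.linear1.weight.device)  # cpu
        print(block_on_cpu.linear1.weight.device)  # cuda:0

Only the .weight tensors are exchanged; biases and buffers stay where they are, which is why the blocks themselves remain on the device and only their weights are parked on the CPU.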
+
+
+def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+    weight_swap_jobs = []
+    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+    stream_to_cpu = torch.cuda.Stream()
+    stream_to_cuda = torch.cuda.Stream()
+
+    events = []
+    with torch.cuda.stream(stream_to_cpu):
+        # cuda to offload
+        offloaded_weights = []
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
+            event = torch.cuda.Event()
+            event.record(stream=stream_to_cpu)
+            events.append(event)
+
+    with torch.cuda.stream(stream_to_cuda):
+        # cpu to cuda
+        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
+            event.synchronize()
+            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+            module_to_cuda.weight.data = cuda_data_view
+
+        # offload to cpu
+        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
+            weight_swap_jobs, offloaded_weights
+        ):
+            module_to_cpu.weight.data = offloaded_weight
+
+        stream_to_cuda.synchronize()
+
+    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
+
+
+def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+    weight_swap_jobs = []
+    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+    stream_to_cpu = torch.cuda.Stream()
+    stream_to_cuda = torch.cuda.Stream()
+
+    # cuda to offload
+    events = []
+    with torch.cuda.stream(stream_to_cpu):
+        offloaded_weights = []
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+            cuda_data_view.record_stream(stream_to_cpu)
+            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
+
+            event = torch.cuda.Event()
+            event.record(stream=stream_to_cpu)
+            events.append(event)
+
+    # cpu to cuda
+    with torch.cuda.stream(stream_to_cuda):
+        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
+            weight_swap_jobs, events, offloaded_weights
+        ):
+            event.synchronize()
+            cuda_data_view.record_stream(stream_to_cuda)
+            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+            module_to_cuda.weight.data = cuda_data_view
+
+            module_to_cpu.weight.data = offloaded_weight
+
+        stream_to_cuda.synchronize()
+
+    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
+    # torch.cuda.current_stream().wait_stream(stream_to_cuda)
+    # for job in weight_swap_jobs:
+    #     job[2].record_stream(torch.cuda.current_stream())  # record the ownership of the tensor
+
+
+def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+    weight_swap_jobs = []
+    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+            if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
+                # one of the modules must have the tensor to offload
+                module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
+                module_to_cpu.offloaded_weight = module_to_cpu.offloaded_weight.pin_memory()  # pin_memory() is not in-place; keep the returned pinned tensor
+            offloaded_weight = (
+                module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
+            )
+            assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
+            weight_swap_jobs.append(
+                (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
+            )
+
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        # cuda to offload
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
+            cuda_data_view.record_stream(stream)
+            offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
+
+        stream.synchronize()
+
+        # cpu to cuda
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
+            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+            module_to_cuda.weight.data = cuda_data_view
+
+        # offload to cpu
+        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
+            module_to_cpu.weight.data = offloaded_weight
+            offloaded_weight = cpu_data_view
+            module_to_cpu.offloaded_weight = offloaded_weight
+            module_to_cuda.offloaded_weight = offloaded_weight
+
+        stream.synchronize()
+
+    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
+
+
+def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+
+    weight_swap_jobs = []
+    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+            if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
+                # one of the modules must have the tensor to cache
+                module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
+                module_to_cpu.__cached_cpu_weight = module_to_cpu.__cached_cpu_weight.pin_memory()  # pin_memory() is not in-place; keep the returned pinned tensor
+
+            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
+
+    for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
+        module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
+        module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
+
+    torch.cuda.current_stream().synchronize()  # wait for the copy from cache to cpu to finish
+    torch.cuda.empty_cache()
+
+
+# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
+#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
+#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
+#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
+#             assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
+#             weight_on_cuda = module_to_cpu.weight
+#             weight_on_cpu = module_to_cuda.weight
+#             cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
+#             event = torch.cuda.current_stream().record_event()
+#             event.synchronize()
+#             weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
+#             weight_on_cpu.data = cuda_to_cpu_data
+#             weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
+
+#             module_to_cpu.weight = weight_on_cpu
+#             module_to_cuda.weight = weight_on_cuda
+
+
+def weighs_to_device(layer: nn.Module, device: torch.device):
+    for module in layer.modules():
+        if hasattr(module, "weight") and module.weight is not None:
+            module.weight.data = module.weight.data.to(device, non_blocking=True)
+

 def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
     """
@@ -313,6 +533,7 @@ def _convert_float8(byte_tensor, dtype_str, shape):
         # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
     raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")

+
 def load_safetensors(
     path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
 ) -> dict[str, torch.Tensor]:
@@ -336,7 +557,6 @@ def load_safetensors(
             return state_dict


-
 # endregion

 # region Image utils
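weighs_to_device moves only the .weight tensors of a layer's submodules, which is what lets prepare_block_swap_before_forward keep a block's small tensors (biases, buffers) on the GPU while parking its large weights on the CPU until they are swapped back in. A small illustration of that split placement, assuming a CUDA device is available; Toy is a made-up module used only for the example:

    import torch
    import torch.nn as nn

    from library.utils import weighs_to_device

    class Toy(nn.Module):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(32, 32)
            self.norm = nn.LayerNorm(32)

    if torch.cuda.is_available():
        block = Toy().to("cuda")        # whole block on the device first
        weighs_to_device(block, "cpu")  # then park only the weights on cpu
        print(block.proj.weight.device)  # cpu
        print(block.proj.bias.device)    # cuda:0
        print(block.norm.weight.device)  # cpu (LayerNorm's scale is also a `weight`)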