faster block swap
kohya-ss committed Nov 5, 2024
1 parent 5e32ee2 commit 81c0c96
Showing 3 changed files with 352 additions and 115 deletions.
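
In short, the commit replaces the previous whole-"block unit" transfers (one double block, or a pair of single blocks, moved with .to() between CPU and GPU) with an in-place weight swap between a GPU-resident block and a CPU-resident block via a new utils.swap_weight_devices helper, and it schedules double and single blocks independently instead of through shared unit indices. The helper itself lives in the third changed file, which is not rendered below; the following is a minimal sketch of the idea, assuming it exchanges parameter data between the two blocks so that module objects and optimizer references stay valid and the existing device allocation is reused. The function name and details here are illustrative, not the repository's actual code.

import torch
import torch.nn as nn


def swap_weight_devices_sketch(block_to_cpu: nn.Module, block_to_cuda: nn.Module) -> None:
    # block_to_cpu holds device-resident weights that should be evicted;
    # block_to_cuda holds cpu-resident weights that should become resident.
    # Both blocks are assumed to have identical structure (double<->double or
    # single<->single), so zipping their parameters pairs them up correctly.
    for p_evict, p_load in zip(block_to_cpu.parameters(), block_to_cuda.parameters()):
        gpu_tensor = p_evict.data                         # existing device allocation
        cpu_copy = gpu_tensor.to("cpu")                   # save the evicted weights on the host
        gpu_tensor.copy_(p_load.data, non_blocking=True)  # reuse the same device memory for the incoming weights
        p_evict.data = cpu_copy
        p_load.data = gpu_tensor
    torch.cuda.synchronize()  # make sure the asynchronous host-to-device copies have landed

Swapping tensor data in place, rather than calling .to() on whole modules and emptying the CUDA cache, avoids reallocating device memory on every transfer and lets the copies run on a background thread while the next block computes, which is presumably where the speedup in the commit title comes from.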
107 changes: 51 additions & 56 deletions flux_train.py
@@ -17,12 +17,14 @@
import os
from multiprocessing import Value
import time
from typing import List
from typing import List, Optional, Tuple, Union
import toml

from tqdm import tqdm

import torch
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -466,45 +468,28 @@ def train(args):

# memory efficient block swapping

def get_block_unit(dbl_blocks, sgl_blocks, index: int):
if index < len(dbl_blocks):
return (dbl_blocks[index],)
else:
index -= len(dbl_blocks)
index *= 2
return (sgl_blocks[index], sgl_blocks[index + 1])

def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc):
# print(f"Backward: Move block {bidx_to_cpu} to CPU")
for block in blocks_to_cpu:
block = block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()

# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
for block in blocks_to_cuda:
block = block.to(dvc, non_blocking=True)

torch.cuda.synchronize()
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}")
return bidx_to_cpu, bidx_to_cuda

blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu)
blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda)

futures[block_idx_to_cuda] = thread_pool.submit(
move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device
)
def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# start_time = time.perf_counter()
# print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event

block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]

futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)

def wait_blocks_move(block_idx, futures):
if block_idx not in futures:
def wait_blocks_move(block_id, futures):
if block_id not in futures:
return
# print(f"Backward: Wait for block {block_idx}")
# print(f"Backward: Wait for block {block_id}")
# start_time = time.perf_counter()
future = futures.pop(block_idx)
future.result()
# print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
# torch.cuda.synchronize()
future = futures.pop(block_id)
_, bidx_to_cuda = future.result()
assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}"
# print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s")
# print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s")

if args.fused_backward_pass:
@@ -513,11 +498,11 @@ def wait_blocks_move(block_idx, futures):

library.adafactor_fused.patch_adafactor_fused(optimizer)

blocks_to_swap = args.blocks_to_swap
double_blocks_to_swap = args.blocks_to_swap // 2
single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
handled_block_ids = set()

n = 1 # only asynchronous purpose, no need to increase this number
# n = 2
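
For concreteness, the split of args.blocks_to_swap between the two block types computed just above (and mirrored in enable_block_swap in library/flux_models.py) works out as follows; the value 18 is only an illustrative choice.

# Illustrative only: how args.blocks_to_swap is divided between double and single blocks.
blocks_to_swap = 18
double_blocks_to_swap = blocks_to_swap // 2                            # 9 double blocks are swapped
single_blocks_to_swap = (blocks_to_swap - double_blocks_to_swap) * 2   # 18 single blocks are swapped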
@@ -530,53 +515,63 @@ def wait_blocks_move(block_idx, futures):
if parameter.requires_grad:
grad_hook = None

if blocks_to_swap:
if double_blocks_to_swap > 0 or single_blocks_to_swap > 0:
is_double = param_name.startswith("double_blocks")
is_single = param_name.startswith("single_blocks")
if is_double or is_single:
if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0:
block_idx = int(param_name.split(".")[1])
unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2
if unit_idx not in handled_unit_indices:
block_id = (is_double, block_idx) # double or single, block index
if block_id not in handled_block_ids:
# swap following (already backpropagated) block
handled_unit_indices.add(unit_idx)
handled_block_ids.add(block_id)

# if n blocks were already backpropagated
num_blocks_propagated = num_block_units - unit_idx - 1
if is_double:
num_blocks = num_double_blocks
blocks_to_swap = double_blocks_to_swap
else:
num_blocks = num_single_blocks
blocks_to_swap = single_blocks_to_swap

# -1 for 0-based index, -1 for current block is not fully backpropagated yet
num_blocks_propagated = num_blocks - block_idx - 2
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap
waiting = unit_idx > 0 and unit_idx <= blocks_to_swap
waiting = block_idx > 0 and block_idx <= blocks_to_swap

if swapping or waiting:
block_idx_to_cpu = num_block_units - num_blocks_propagated
block_idx_to_cpu = num_blocks - num_blocks_propagated
block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
block_idx_to_wait = unit_idx - 1
block_idx_to_wait = block_idx - 1

# create swap hook
def create_swap_grad_hook(
bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool
is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool
):
def __grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None

# print(f"Backward: {uidx}, {swpng}, {wtng}")
# print(
# f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}"
# )
if swpng:
submit_move_blocks(
futures,
thread_pool,
bidx_to_cpu,
bidx_to_cuda,
flux.double_blocks,
flux.single_blocks,
accelerator.device,
flux.double_blocks if is_dbl else flux.single_blocks,
(is_dbl, bidx_to_cuda), # wait for this block
)
if wtng:
wait_blocks_move(bidx_to_wait, futures)
wait_blocks_move((is_dbl, bidx_to_wait), futures)

return __grad_hook

grad_hook = create_swap_grad_hook(
block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting
is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting
)

if grad_hook is None:
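The index arithmetic inside the grad hooks above is easiest to check with concrete numbers. The trace below is illustrative only and assumes FLUX.1's 19 double blocks with blocks_to_swap = 18, i.e. double_blocks_to_swap = 9; it replays the backward order and prints the (to CPU, to CUDA) pairs the hooks would submit.

# Illustrative replay of the backward-pass swap schedule for the double blocks
# (assumed counts: 19 double blocks, 9 of them swapped).
num_double_blocks = 19
double_blocks_to_swap = 9

for block_idx in reversed(range(num_double_blocks)):
    # -1 for 0-based index, -1 because the current block is not fully backpropagated yet
    num_blocks_propagated = num_double_blocks - block_idx - 2
    swapping = 0 < num_blocks_propagated <= double_blocks_to_swap
    if swapping:
        block_idx_to_cpu = num_double_blocks - num_blocks_propagated
        block_idx_to_cuda = double_blocks_to_swap - num_blocks_propagated
        print(f"after block {block_idx}: block {block_idx_to_cpu} -> cpu, block {block_idx_to_cuda} -> cuda")

The schedule evicts blocks 18 down to 10 and prefetches blocks 8 down to 0, so backward finishes with blocks 0-9 resident on the device, which matches the layout prepare_block_swap_before_forward in library/flux_models.py establishes before the next forward pass.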
138 changes: 80 additions & 58 deletions library/flux_models.py
@@ -7,8 +7,9 @@
import math
import os
import time
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from library import utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -923,7 +924,8 @@ def __init__(self, params: FluxParams):
self.blocks_to_swap = None

self.thread_pool: Optional[ThreadPoolExecutor] = None
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)

@property
def device(self):
@@ -963,14 +965,17 @@ def disable_gradient_checkpointing(self):

def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
self.double_blocks_to_swap = num_blocks // 2
self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}."
)

n = 1 # async block swap. 1 is enough
# n = 2
# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)

def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
@@ -983,31 +988,55 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks

def get_block_unit(self, index: int):
if index < len(self.double_blocks):
return (self.double_blocks[index],)
else:
index -= len(self.double_blocks)
index *= 2
return self.single_blocks[index], self.single_blocks[index + 1]
# def get_block_unit(self, index: int):
# if index < len(self.double_blocks):
# return (self.double_blocks[index],)
# else:
# index -= len(self.double_blocks)
# index *= 2
# return self.single_blocks[index], self.single_blocks[index + 1]

def get_unit_index(self, is_double: bool, index: int):
if is_double:
return index
else:
return len(self.double_blocks) + index // 2
# def get_unit_index(self, is_double: bool, index: int):
# if is_double:
# return index
# else:
# return len(self.double_blocks) + index // 2

def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu
# # make: first n blocks are on cuda, and last n blocks are on cpu
# if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# # raise ValueError("Block swap is not enabled.")
# return
# for i in range(self.num_block_units - self.blocks_to_swap):
# for b in self.get_block_unit(i):
# b.to(self.device)
# for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
# for b in self.get_block_unit(i):
# b.to("cpu")
# clean_memory_on_device(self.device)

# all blocks are on device, but some weights are on cpu
# make first n blocks weights on device, and last n blocks weights on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
for b in self.get_block_unit(i):
b.to("cpu")

for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)

for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]:
b.to(self.device)
utils.weighs_to_device(b, self.device) # make sure weights are on device
for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]:
b.to(self.device) # move block to device first
utils.weighs_to_device(b, "cpu") # make sure weights are on cpu
torch.cuda.synchronize()
clean_memory_on_device(self.device)

def forward(
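
prepare_block_swap_before_forward above calls block.to(self.device) on every block and then utils.weighs_to_device(b, "cpu") on the tail blocks, so (per the comments in the hunk) the module objects stay on the device while only their weights are offloaded. That helper also comes from the third changed file, which is not rendered here; a minimal sketch of what such a helper could look like, under that assumption, is:

from typing import Union

import torch
import torch.nn as nn


def weighs_to_device_sketch(layer: nn.Module, device: Union[torch.device, str]) -> None:
    # Retarget only the .weight tensors of the block's submodules; the modules
    # themselves are not moved, so buffers, hooks and other state are untouched.
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)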
@@ -1044,27 +1073,22 @@ def forward(
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
futures = {}

def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
for block in blocks_to_cpu:
block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# device = self.device

# print(f"Moving {bidx_to_cuda} to cuda.")
for block in blocks_to_cuda:
block.to(self.device, non_blocking=True)

torch.cuda.synchronize()
def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
start_time = time.perf_counter()
# print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.")
utils.swap_weight_devices(block_to_cpu, block_to_cuda)
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda

blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
return block_idx_to_cpu, block_idx_to_cuda # , event

block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)

def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
@@ -1073,37 +1097,35 @@ def wait_for_blocks_move(block_idx, ftrs):
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
# print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds")

double_futures = {}
for block_idx, block in enumerate(self.double_blocks):
# print(f"Double block {block_idx}")
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, double_futures)

img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)

if unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.double_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx
future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda)
double_futures[block_idx_to_cuda] = future

img = torch.cat((txt, img), 1)

single_futures = {}
for block_idx, block in enumerate(self.single_blocks):
# print(f"Single block {block_idx}")
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
if block_idx % 2 == 0:
wait_for_blocks_move(unit_idx, futures)
wait_for_blocks_move(block_idx, single_futures)

img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)

if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
if block_idx < self.single_blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx
future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda)
single_futures[block_idx_to_cuda] = future

img = img[:, txt.shape[1] :, ...]

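As with the backward hooks in flux_train.py, the forward rotation is easiest to sanity-check numerically. The trace below is illustrative only, assuming FLUX.1's 19 double blocks and double_blocks_to_swap = 9; it mirrors the loop above and tracks which blocks hold their weights on the device.

# Illustrative replay of the forward-pass rotation for the double blocks
# (assumed counts: 19 double blocks, 9 of them swapped).
num_double_blocks = 19
double_blocks_to_swap = 9

# prepare_block_swap_before_forward leaves the first 10 blocks resident.
resident = set(range(num_double_blocks - double_blocks_to_swap))

for block_idx in range(num_double_blocks):
    # The block being computed is always resident: it either started on the device
    # or an earlier iteration prefetched it (wait_for_blocks_move in the real code).
    assert block_idx in resident
    if block_idx < double_blocks_to_swap:
        resident.discard(block_idx)                                           # evict the block just used
        resident.add(num_double_blocks - double_blocks_to_swap + block_idx)   # prefetch a tail block

print(sorted(resident))  # [9, 10, ..., 18]: the layout the backward hooks then unwind

The single blocks follow the same pattern with their own swap count; the difference from the previous version is that they are no longer paired into two-block units.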
(The diff for the third changed file is not shown here.)
