
Commit

Fix DDP issues and Support DDP for all training scripts (bmaltais#448)
* Fix DDP bugs

* Fix DDP bugs for finetune and db

* refactor model loader

* fix DDP network

* try to fix DDP network in train unet only

* remove unused DDP import

* refactor DDP transform

* refactor DDP transform

* fix sample images bugs

* change DDP transform location

* add autocast to train_db

* support DDP in XTI

* Clear DDP import
Isotr0py authored May 3, 2023
1 parent a7485e4 commit e1143ca
Showing 7 changed files with 58 additions and 37 deletions.
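The thread running through all seven files: under a multi-GPU launch, accelerator.prepare() hands each trainable model back wrapped in torch's DistributedDataParallel, which proxies forward() but hides the module's own attributes. A minimal, self-contained sketch of that failure mode (toy Linear model and a single-process gloo group, not code from this repo):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process process group so DDP can be constructed on CPU for demonstration
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

net = torch.nn.Linear(4, 4)
wrapped = DDP(net)

print(hasattr(wrapped, "weight"))         # False: the wrapper hides the module's attributes
print(hasattr(wrapped.module, "weight"))  # True: .module is the original model
dist.destroy_process_group()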
5 changes: 4 additions & 1 deletion fine_tune.py
@@ -90,7 +90,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the model
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

     # verify load/save model formats
     if load_stable_diffusion_format:
@@ -228,6 +228,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
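The placement is the point of the fix: the wrappers only exist after accelerator.prepare(), so unwrapping has to happen between prepare() and the first attribute access. A condensed sketch of the flow fine_tune.py now follows (names taken from the script; assumes the surrounding training setup):

from torch.nn.parallel import DistributedDataParallel as DDP

def prepare_and_unwrap(accelerator, text_encoder, unet, optimizer, dataloader, scheduler):
    # prepare() is where DDP wrappers appear under multi-GPU launch
    unet, optimizer, dataloader, scheduler = accelerator.prepare(unet, optimizer, dataloader, scheduler)
    # unwrap right away so the rest of the script sees plain nn.Modules
    text_encoder, unet = (m.module if isinstance(m, DDP) else m for m in (text_encoder, unet))
    return text_encoder, unet, optimizer, dataloader, scheduler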
34 changes: 33 additions & 1 deletion library/train_util.py
@@ -19,6 +19,7 @@
     Union,
 )
 from accelerate import Accelerator
+import gc
 import glob
 import math
 import os
@@ -30,6 +31,7 @@

 from tqdm import tqdm
 import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
@@ -2866,7 +2868,7 @@ def prepare_dtype(args: argparse.Namespace):
     return weight_dtype, save_dtype


-def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
@@ -2895,6 +2897,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     return text_encoder, vae, unet, load_stable_diffusion_format


+def transform_DDP(text_encoder, unet, network=None):
+    # Transform text_encoder, unet and network from DistributedDataParallel
+    return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])
+
+
+def load_target_model(args, weight_dtype, accelerator):
+    # load models for each process
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
+                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
+
+    return text_encoder, vae, unet, load_stable_diffusion_format


 def patch_accelerator_for_fp16_training(accelerator):
     org_unscale_grads = accelerator.scaler._unscale_grads_
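Two details in the new loader are worth spelling out. First, transform_DDP returns a generator expression, so callers must always unpack exactly three values; call sites that pass no network still unpack a third placeholder, which is why they end in ", _ =". Second, the loop over num_processes staggers checkpoint loading, apparently so only one rank reads from disk at a time while the others wait at the barrier, keeping peak host RAM and disk contention near a single process's load. A stripped-down sketch of that idiom (load_fn is a hypothetical stand-in for _load_target_model):

def load_in_turn(accelerator, load_fn):
    result = None
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            result = load_fn()           # only rank pi touches the disk this iteration
        accelerator.wait_for_everyone()  # barrier: all ranks sync before the next turn
    return result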
2 changes: 1 addition & 1 deletion networks/lora_interrogator.py
@@ -23,7 +23,7 @@ def interrogate(args):
     print(f"loading SD model: {args.sd_model}")
     args.pretrained_model_name_or_path = args.sd_model
     args.vae = None
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weights_dtype, DEVICE)
+    text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)

     print(f"loading LoRA: {args.model}")
     network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
8 changes: 6 additions & 2 deletions train_db.py
@@ -92,7 +92,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the model
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

     # verify load/save model formats
     if load_stable_diffusion_format:
@@ -196,6 +196,9 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

@@ -297,7 +300,8 @@ def train(args):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with accelerator.autocast():
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 if args.v_parameterization:
                     # v-parameterization training
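accelerator.autocast() opens the same mixed-precision context that accelerate applies to a prepared model's forward pass; once the UNet is unwrapped and called directly, the commit applies it explicitly. A self-contained toy sketch of the API (assumes accelerate is installed; fp16 here needs a CUDA device):

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # matches --mixed_precision fp16
model = torch.nn.Linear(8, 8).to(accelerator.device)
x = torch.randn(2, 8, device=accelerator.device)

with accelerator.autocast():
    y = model(x)  # the matmul runs under autocast inside the context
print(y.dtype)    # torch.float16 on CUDA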
36 changes: 6 additions & 30 deletions train_network.py
@@ -1,4 +1,3 @@
-from torch.nn.parallel import DistributedDataParallel as DDP
 import importlib
 import argparse
 import gc
@@ -144,24 +143,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the model
-    for pi in range(accelerator.state.num_processes):
-        # TODO: modify other training scripts as well
-        if pi == accelerator.state.local_process_index:
-            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
-
-            text_encoder, vae, unet, _ = train_util.load_target_model(
-                args, weight_dtype, accelerator.device if args.lowram else "cpu"
-            )
-
-            # work on low-ram device
-            if args.lowram:
-                text_encoder.to(accelerator.device)
-                unet.to(accelerator.device)
-                vae.to(accelerator.device)
-
-            gc.collect()
-            torch.cuda.empty_cache()
-        accelerator.wait_for_everyone()
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # incorporate xformers or memory efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -279,6 +261,9 @@ def train(args):
     else:
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

+    # transform DDP after prepare (train_network here only)
+    text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
+
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.requires_grad_(False)
@@ -288,20 +273,11 @@ def train(args):
         text_encoder.train()

         # set top parameter requires_grad = True for gradient checkpointing works
-        if type(text_encoder) == DDP:
-            text_encoder.module.text_model.embeddings.requires_grad_(True)
-        else:
-            text_encoder.text_model.embeddings.requires_grad_(True)
+        text_encoder.text_model.embeddings.requires_grad_(True)
     else:
         unet.eval()
         text_encoder.eval()

-    # support DistributedDataParallel
-    if type(text_encoder) == DDP:
-        text_encoder = text_encoder.module
-        unet = unet.module
-        network = network.module
-
     network.prepare_grad_etc(text_encoder, unet)

     if not cache_latents:
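The net effect in train_network.py is that the per-call-site DDP branches collapse into one unwrap. Dotted attribute chains such as text_encoder.text_model.embeddings, and custom methods like network.prepare_grad_etc(...), only resolve on the underlying module, so before this commit every such access needed a guard; after transform_DDP runs, only the plain branch is ever taken. A sketch restating the removed guard (illustration only, not new repo code):

from torch.nn.parallel import DistributedDataParallel as DDP

def embeddings_of(text_encoder):
    # pre-commit pattern: branch at every access site
    if isinstance(text_encoder, DDP):
        return text_encoder.module.text_model.embeddings
    # post-commit: transform_DDP guarantees we are already unwrapped
    return text_encoder.text_model.embeddings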
5 changes: 4 additions & 1 deletion train_textual_inversion.py
@@ -98,7 +98,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the model
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # Convert the init_word to token_id
     if args.init_word is not None:
@@ -280,6 +280,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
5 changes: 4 additions & 1 deletion train_textual_inversion_XTI.py
@@ -104,7 +104,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)

     # load the model
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

     # Convert the init_word to token_id
     if args.init_word is not None:
@@ -314,6 +314,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )

+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
