add FLUX.1 support
kohya-ss committed Oct 18, 2024
1 parent 3cc5b8d commit ef70aa7
Showing 2 changed files with 103 additions and 39 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -11,6 +11,25 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 19, 2024:

- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training.
  - This method keeps the LoRA's output for captions that do not contain the trigger words close to the output of the base model (LoRA not applied).
  - Define a dataset subset for the regularization images (`is_reg = true`) in the `.toml` config and add `custom_attributes.diff_output_preservation = true` to it.
  - See [dataset configuration](docs/config_README-en.md) for how to set up the regularization dataset.
  - Make sure that "number of training images x number of epochs >= number of regularization images x number of epochs".
  - Specify a large value for the `--prior_loss_weight` option (a command-line option, not a dataset config setting); 10-1000 is recommended.
  - Choose the value so that the loss when training without the regularization images stays close to the loss when training with DOP; a simplified sketch of how this weight enters the loss follows the config example below.
```
[[datasets.subsets]]
image_dir = "path/to/image/dir"
num_repeats = 1
is_reg = true
custom_attributes.diff_output_preservation = true # Add this
```
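
For intuition, here is a minimal, simplified sketch (not the actual training-script code) of how the DOP target swap and `--prior_loss_weight` combine: samples from a subset with `diff_output_preservation = true` are trained toward the base model's prediction (LoRA multiplier set to 0), and their loss term is scaled by the prior loss weight. The names `dop_weighted_loss`, `is_dop`, and `prior_pred` are illustrative only, not part of the scripts.

```
import torch
import torch.nn.functional as F


def dop_weighted_loss(model_pred, target, prior_pred, is_dop, prior_loss_weight):
    # model_pred:  (B, ...) prediction with the LoRA applied
    # target:      (B, ...) normal flow matching target (noise - latents)
    # prior_pred:  (N, ...) prediction with the LoRA multiplier set to 0.0,
    #              computed only for the N samples flagged by is_dop
    # is_dop:      (B,) bool mask, True for DOP regularization samples
    target = target.clone()
    target[is_dop] = prior_pred.to(target.dtype)  # DOP: train toward the base model output

    per_sample = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    per_sample = per_sample.reshape(per_sample.shape[0], -1).mean(dim=1)

    # --prior_loss_weight scales the regularization (DOP) samples, which is why
    # a fairly large value (10-1000) is suggested above.
    weights = torch.ones_like(per_sample)
    weights[is_dop] = float(prior_loss_weight)
    return (per_sample * weights).mean()
```

In the actual script, `is_dop` corresponds to `diff_output_pr_indices` built from `batch["custom_attributes"]` (see the `flux_train_network.py` diff below), and the prior prediction is obtained by calling the model again after `network.set_multiplier(0.0)`.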
Oct 13, 2024:

- Fixed an issue where loading image sizes took a long time when initializing the dataset, especially for datasets with a large number of images.
123 changes: 84 additions & 39 deletions flux_train_network.py
@@ -373,33 +373,13 @@ def get_noise_pred_and_target(
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        if not args.split_mode:
            # normal forward
            with accelerator.autocast():
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                model_pred = unet(
                    img=packed_noisy_model_input,
                    img_ids=img_ids,
                    txt=t5_out,
                    txt_ids=txt_ids,
                    y=l_pooled,
                    timesteps=timesteps / 1000,
                    guidance=guidance_vec,
                    txt_attention_mask=t5_attn_mask,
                )
        else:
            # split forward to reduce memory usage
            assert network.train_blocks == "single", "train_blocks must be single for split mode"
            with accelerator.autocast():
                # move flux lower to cpu, and then move flux upper to gpu
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)
                self.flux_upper.to(accelerator.device)

                # upper model does not require grad
                with torch.no_grad():
                    intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
                        img=packed_noisy_model_input,
        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
            if not args.split_mode:
                # normal forward
                with accelerator.autocast():
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                    model_pred = unet(
                        img=img,
                        img_ids=img_ids,
                        txt=t5_out,
                        txt_ids=txt_ids,
@@ -408,18 +408,52 @@ def get_noise_pred_and_target(
                        guidance=guidance_vec,
                        txt_attention_mask=t5_attn_mask,
                    )

                # move flux upper back to cpu, and then move flux lower to gpu
                self.flux_upper.to("cpu")
                clean_memory_on_device(accelerator.device)
                unet.to(accelerator.device)

                # lower model requires grad
                intermediate_img.requires_grad_(True)
                intermediate_txt.requires_grad_(True)
                vec.requires_grad_(True)
                pe.requires_grad_(True)
                model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
            else:
                # split forward to reduce memory usage
                assert network.train_blocks == "single", "train_blocks must be single for split mode"
                with accelerator.autocast():
                    # move flux lower to cpu, and then move flux upper to gpu
                    unet.to("cpu")
                    clean_memory_on_device(accelerator.device)
                    self.flux_upper.to(accelerator.device)

                    # upper model does not require grad
                    with torch.no_grad():
                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
                            img=packed_noisy_model_input,
                            img_ids=img_ids,
                            txt=t5_out,
                            txt_ids=txt_ids,
                            y=l_pooled,
                            timesteps=timesteps / 1000,
                            guidance=guidance_vec,
                            txt_attention_mask=t5_attn_mask,
                        )

                    # move flux upper back to cpu, and then move flux lower to gpu
                    self.flux_upper.to("cpu")
                    clean_memory_on_device(accelerator.device)
                    unet.to(accelerator.device)

                    # lower model requires grad
                    intermediate_img.requires_grad_(True)
                    intermediate_txt.requires_grad_(True)
                    vec.requires_grad_(True)
                    pe.requires_grad_(True)
                    model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)

            return model_pred

        model_pred = call_dit(
            img=packed_noisy_model_input,
            img_ids=img_ids,
            t5_out=t5_out,
            txt_ids=txt_ids,
            l_pooled=l_pooled,
            timesteps=timesteps,
            guidance_vec=guidance_vec,
            t5_attn_mask=t5_attn_mask,
        )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
@@ -430,6 +444,37 @@ def get_noise_pred_and_target(
        # flow matching loss: this is different from SD3
        target = noise - latents

        # differential output preservation
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                with torch.no_grad(), accelerator.autocast():
                    model_pred_prior = call_dit(
                        img=packed_noisy_model_input[diff_output_pr_indices],
                        img_ids=img_ids[diff_output_pr_indices],
                        t5_out=t5_out[diff_output_pr_indices],
                        txt_ids=txt_ids[diff_output_pr_indices],
                        l_pooled=l_pooled[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
                    )
                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

        return model_pred, target, timesteps, None, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
