diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md
new file mode 100644
index 0000000000000..398ae95431504
--- /dev/null
+++ b/examples/cogvideo/README.md
@@ -0,0 +1,228 @@
+# LoRA finetuning example for CogVideoX
+
+Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow control over the extent to which the model is adapted toward new training data via a `scale` parameter.
+
+At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).
+
+## Data Preparation
+
+The training script accepts data in two formats.
+
+**First data format**
+
+Two files, where one file contains line-separated prompts and the other contains line-separated paths to video data (the paths must be relative to the path you pass when specifying `--instance_data_root`). Let's take a look at an example to understand this better!
+
+Assume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.
+
+The `prompts.txt` file should contain line-separated prompts:
+
+```
+A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
+A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.
+...
+```
+
+The `videos.txt` file should contain line-separated paths to video files. Note that the paths should be _relative_ to the `--instance_data_root` directory.
+
+```
+videos/00000.mp4
+videos/00001.mp4
+...
+```
+
+Overall, this is how your dataset would look if you ran the `tree` command on the dataset root directory:
+
+```
+/dataset
+├── prompts.txt
+├── videos.txt
+├── videos
+    ├── 00000.mp4
+    ├── 00001.mp4
+    ├── ...
+```
+
+When using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.
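+
+If you want to sanity-check this layout before launching a run, a small script along the following lines can help. It is only a sketch (it assumes the `/dataset` layout shown above) and mirrors the checks the training script performs when it reads the two files:
+
+```python
+from pathlib import Path
+
+root = Path("/dataset")  # same value you would pass to --instance_data_root
+prompts = [line.strip() for line in (root / "prompts.txt").read_text().splitlines() if line.strip()]
+videos = [line.strip() for line in (root / "videos.txt").read_text().splitlines() if line.strip()]
+
+# The trainer expects one prompt per video, and every listed path must resolve to a real file.
+assert len(prompts) == len(videos), f"{len(prompts)} prompts vs {len(videos)} video paths"
+missing = [v for v in videos if not (root / v).is_file()]
+assert not missing, f"Missing video files: {missing[:5]}"
+print(f"OK: {len(videos)} (prompt, video) pairs found")
+```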
+
+**Second data format**
+
+You could use a single CSV file. For the sake of this example, assume you have a `metadata.csv` file. The expected format is:
+
+```
+<CAPTION_COLUMN>,<PATH_TO_VIDEO_COLUMN>
+"""A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.""","""00000.mp4"""
+"""A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.""","""00001.mp4"""
+...
+```
+
+In this case, the `--instance_data_root` should be the location where the videos are stored, and `--dataset_name` should be either a path to a local folder or the name of (or URL to) a hosted HF Dataset repository that `load_dataset` can handle. Assuming you have videos of your Minecraft gameplay at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would have to specify `my-awesome-username/minecraft-videos`.
+
+When using this format, the `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<PATH_TO_VIDEO_COLUMN>`.
+
+You are not strictly restricted to the CSV format. Any file format works as long as the `load_dataset` method supports it for loading a basic `<CAPTION_COLUMN>` and `<PATH_TO_VIDEO_COLUMN>`. The reason for going through these dataset organization gymnastics for loading video data is that we found `load_dataset` from the datasets library to not fully support all kinds of video formats. This will undoubtedly be improved in the future.
+
+> [!NOTE]
+> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example for this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm), and a length of 50-100 words is considered good.
+
+> [!NOTE]
+> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:
+> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.
+> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.
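+
+Before moving on to training, you can also quickly verify that 🤗 Datasets is able to read your file and that it exposes the columns you plan to pass as `--caption_column` and `--video_column`. A minimal sketch, assuming the `metadata.csv` example above (substitute your own file and column names):
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("csv", data_files="metadata.csv")
+print(dataset["train"].column_names)  # should list your caption and video-path columns
+print(dataset["train"][0])            # first (caption, video path) row
+```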
+
+## Training
+
+You need to set up your development environment by installing the necessary requirements. The following packages are required:
+- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)
+- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related
+- `pip install datasets decord` for loading video training data
+- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training
+- `pip install wandb` optionally for monitoring training logs
+- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training
+- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups. Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+If you would like to push your model to the HF Hub with a neat model card after training is completed, make sure you're logged in:
+
+```
+huggingface-cli login
+
+# Alternatively, you could upload your model manually using:
+# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
+```
+
+Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
+
+Assuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1` (see the short sketch after this list for the arithmetic):
+- 1500 steps on 50 videos would correspond to `30` training epochs
+- 4000 steps on 100 videos would correspond to `40` training epochs
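+
+The conversion above follows the same bookkeeping the training script does: it computes the number of update steps per epoch from the dataset size and batch size, then derives the number of epochs from the total step budget. A minimal sketch of that arithmetic (ignoring gradient accumulation and multi-GPU training, both of which reduce the number of updates per epoch):
+
+```python
+import math
+
+num_videos = 50         # number of (prompt, video) pairs in your dataset
+train_batch_size = 1    # --train_batch_size
+max_train_steps = 1500  # total optimization steps you plan to run
+
+steps_per_epoch = math.ceil(num_videos / train_batch_size)
+num_epochs = math.ceil(max_train_steps / steps_per_epoch)
+print(num_epochs)  # 30 for the numbers above
+```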
+
+```bash
+#!/bin/bash
+
+GPU_IDS="0"
+
+accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
+  --pretrained_model_name_or_path THUDM/CogVideoX-2b \
+  --cache_dir <CACHE_DIR> \
+  --instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \
+  --dataset_name my-awesome-name/my-awesome-dataset \
+  --caption_column <CAPTION_COLUMN> \
+  --video_column <VIDEO_COLUMN> \
+  --id_token <ID_TOKEN> \
+  --validation_prompt "<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
+  --validation_prompt_separator ::: \
+  --num_validation_videos 1 \
+  --validation_epochs 10 \
+  --seed 42 \
+  --rank 64 \
+  --lora_alpha 64 \
+  --mixed_precision fp16 \
+  --output_dir /raid/aryan/cogvideox-lora \
+  --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
+  --train_batch_size 1 \
+  --num_train_epochs 30 \
+  --checkpointing_steps 1000 \
+  --gradient_accumulation_steps 1 \
+  --learning_rate 1e-3 \
+  --lr_scheduler cosine_with_restarts \
+  --lr_warmup_steps 200 \
+  --lr_num_cycles 1 \
+  --enable_slicing \
+  --enable_tiling \
+  --optimizer Adam \
+  --adam_beta1 0.9 \
+  --adam_beta2 0.95 \
+  --max_grad_norm 1.0 \
+  --report_to wandb
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `--validation_prompt` and `--validation_epochs` allow the script to do a few validation inference runs. This lets us qualitatively check if the training is progressing as expected.
+
+Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found it to work better (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth)-like training) than without. When provided, the `<ID_TOKEN>` is prepended to each prompt. So, if your `<ID_TOKEN>` was `"DISNEY"` and your prompt was `"Spiderman swinging over buildings"`, the effective prompt used in training would be `"DISNEY Spiderman swinging over buildings"`. When not provided, you would either be training without any such additional token, or you could augment your dataset to apply the token where you wish before starting the training.
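+
+If you prefer the second option (baking the token into the captions yourself instead of passing `--id_token`), a minimal sketch for the first data format could look like this. The token value and file path below are assumptions; adjust them to your setup:
+
+```python
+from pathlib import Path
+
+token = "DISNEY"  # the identifier you would otherwise pass via --id_token
+prompts_file = Path("/dataset/prompts.txt")
+
+# Prepend the token to every non-empty caption (the trainer would otherwise prepend --id_token itself).
+lines = [f"{token} {line}" if line.strip() else line for line in prompts_file.read_text().splitlines()]
+prompts_file.write_text("\n".join(lines) + "\n")
+```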
+
+> [!TIP]
+> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+
+> [!IMPORTANT]
+> The following settings have been tested at the time of adding CogVideoX LoRA training support:
+> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon.
+> - One dataset comprised 70 training videos of resolution `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation, and because the maximum frame count recommended by the CogVideoX team is 49. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in nature in terms of the concept being trained.
+> - 25+ videos worked best for training new concepts and styles.
+> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training, but normal finetuning without such a token works too.
+> - The trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.
+> - The original repository uses a `lora_alpha` of `1`. We found this unsuitable in many runs, possibly due to differences in modeling backends and training settings. Our recommendation is to set `lora_alpha` to either `rank` or `rank // 2`.
+> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and is also the recommendation of the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
+> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best results. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
+> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From our very limited testing, we found a learning rate of `0.5` to be suitable, in addition to `--prodigy_use_bias_correction`, `--prodigy_safeguard_warmup` and `--prodigy_decouple`.
+> - The recommended learning rate by the CogVideoX authors, and from our experimentation with Adam/AdamW, is between `1e-3` and `1e-4` for a dataset of 25+ videos.
+>
+> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
+
+## Inference
+
+Once you have trained a LoRA model, inference can be done by simply loading the LoRA weights into the `CogVideoXPipeline`.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
+# pipe.load_lora_weights("/path/to/lora/weights", adapter_name="cogvideox-lora")  # Or,
+pipe.load_lora_weights("my-awesome-hf-username/my-awesome-lora-name", adapter_name="cogvideox-lora")  # If loading from the HF Hub
+pipe.to("cuda")
+
+# Assuming lora_alpha=32 and rank=64 for training. If different, set accordingly
+pipe.set_adapters(["cogvideox-lora"], [32 / 64])
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The "
+    "panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance"
+)
+frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
+export_to_video(frames, "output.mp4", fps=8)
+```
diff --git a/examples/cogvideo/requirements.txt b/examples/cogvideo/requirements.txt
new file mode 100644
index 0000000000000..c2238804be9f4
--- /dev/null
+++ b/examples/cogvideo/requirements.txt
@@ -0,0 +1,10 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft>=0.11.1
+sentencepiece
+decord>=0.6.0
+imageio-ffmpeg
\ No newline at end of file
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
new file mode 100644
index 0000000000000..137f3222f6d90
--- /dev/null
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -0,0 +1,1544 @@
+# Copyright 2024 The HuggingFace Team.
+# All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer + +import diffusers +from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.optimization import get_scheduler +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.training_utils import ( + cast_training_params, + clear_objs_and_retain_memory, +) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." 
+ ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + + return parser.parse_args() + + +class VideoDataset(Dataset): + def __init__( + self, + instance_data_root: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_config_name: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + height: int = 480, + width: int = 720, + fps: int = 8, + max_num_frames: int = 49, + skip_frames_start: int = 0, + skip_frames_end: int = 0, + cache_dir: Optional[str] = None, + id_token: Optional[str] = None, + ) -> None: + super().__init__() + + self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.caption_column = caption_column + self.video_column = video_column + self.height = height + self.width = width + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.cache_dir = cache_dir + self.id_token = id_token or "" + + if dataset_name is not None: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() + else: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() + + self.num_instance_videos = len(self.instance_video_paths) + if self.num_instance_videos != len(self.instance_prompts): + raise ValueError( + f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.instance_videos = self._preprocess_data() + + def __len__(self): + return self.num_instance_videos + + def __getitem__(self, index): + return { + "instance_prompt": self.id_token + self.instance_prompts[index], + "instance_video": self.instance_videos[index], + } + + def _load_dataset_from_hub(self): + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_root instead." + ) + + # Downloading and loading a dataset from the hub. See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + self.dataset_name, + self.dataset_config_name, + cache_dir=self.cache_dir, + ) + column_names = dataset["train"].column_names + + if self.video_column is None: + video_column = column_names[0] + logger.info(f"`video_column` defaulting to {video_column}") + else: + video_column = self.video_column + if video_column not in column_names: + raise ValueError( + f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if self.caption_column is None: + caption_column = column_names[1] + logger.info(f"`caption_column` defaulting to {caption_column}") + else: + caption_column = self.caption_column + if self.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{self.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + + instance_prompts = dataset["train"][caption_column] + instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]] + + return instance_prompts, instance_videos + + def _load_dataset_from_local_path(self): + if not self.instance_data_root.exists(): + raise ValueError("Instance videos root folder does not exist") + + prompt_path = self.instance_data_root.joinpath(self.caption_column) + video_path = self.instance_data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + instance_videos = [ + self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0 + ] + + if any(not path.is_file() for path in instance_videos): + raise ValueError( + "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return instance_prompts, instance_videos + + def _preprocess_data(self): + try: + import decord + except ImportError: + raise ImportError( + "The `decord` package is required for loading the video dataset. 
Install with `pip install decord`" + ) + + decord.bridge.set_bridge("torch") + + videos = [] + train_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), + ] + ) + + for filename in self.instance_video_paths: + video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) + video_num_frames = len(video_reader) + + start_frame = min(self.skip_frames_start, video_num_frames) + end_frame = max(0, video_num_frames - self.skip_frames_end) + if end_frame <= start_frame: + frames = video_reader.get_batch([start_frame]) + elif end_frame - start_frame <= self.max_num_frames: + frames = video_reader.get_batch(list(range(start_frame, end_frame))) + else: + indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) + frames = video_reader.get_batch(indices) + + # Ensure that we don't go over the limit + frames = frames[: self.max_num_frames] + selected_num_frames = frames.shape[0] + + # Choose first (4k + 1) frames as this is how many is required by the VAE + remainder = (3 + (selected_num_frames % 4)) % 4 + if remainder != 0: + frames = frames[:-remainder] + selected_num_frames = frames.shape[0] + + assert (selected_num_frames - 1) % 4 == 0 + + # Training transforms + frames = frames.float() + frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) + videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] + + return videos + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}} + ) + + model_description = f""" +# CogVideoX LoRA - {repo_id} + + + +## Model description + +These are {repo_id} LoRA weights for {base_model}. + +The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +Was LoRA for the text encoder enabled? No. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import CogVideoXPipeline +import torch + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"]) + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. 
+pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) + pipe = pipe.to(accelerator.device) + # pipe.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + clear_objs_and_retain_memory([pipe]) + + return videos + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids 
is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. 
Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]): + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.float16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + CogVideoXPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. 
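+    # With fp16 mixed precision, keeping the small set of trainable LoRA parameters in fp32
+    # avoids gradient underflow while adding only negligible memory overhead.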
+ if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) + + # Dataset and DataLoader + train_dataset = VideoDataset( + instance_data_root=args.instance_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + caption_column=args.caption_column, + video_column=args.video_column, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + ) + + def encode_video(video): + video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) + video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(video).latent_dist + return latent_dist + + train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] + + def collate_fn(examples): + videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + videos = torch.cat(videos) + videos = videos.to(memory_format=torch.contiguous_format).float() + + return { + "videos": videos, + "prompts": prompts, + } + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
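+    # (accelerator.prepare shards the dataloader across processes, so its length can change.)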
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + logger.info("***** Running training *****") + logger.info(f" Num trainable parameters = {num_trainable_parameters}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + + with accelerator.accumulate(models_to_accumulate): + model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] + prompts = batch["prompts"] + + # encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device + ) + timesteps = timesteps.long() + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=args.height, + width=args.width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + alphas_cumprod = scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = 
checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + # Create pipeline + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + 
repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index bccd37ddc42fe..bf72122168455 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -67,6 +67,7 @@ def text_encoder_attn_modules(text_encoder): "StableDiffusionXLLoraLoaderMixin", "LoraLoaderMixin", "FluxLoraLoaderMixin", + "CogVideoXLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -84,6 +85,7 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, LoraLoaderMixin, SD3LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7d644d6841533..ba1435a8cbdc6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2276,6 +2276,339 @@ def save_lora_weights( ) +class CogVideoXLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. 
+ revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
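For orientation, `load_lora_weights` is the method an end user calls once training has produced an adapter. A rough usage sketch (the repository id below is a placeholder, while `pytorch_lora_weights.safetensors` is the default file name written by `save_lora_weights`):

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# "my-user/my-cogvideox-lora" is a hypothetical repo or local folder holding a trained adapter.
pipe.load_lora_weights(
    "my-user/my-cogvideox-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="cogvideox-lora",
)

# The example training script scales the adapter by lora_alpha / rank at inference time.
lora_scaling = 64 / 128  # hypothetical alpha and rank
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
```

Internally the call first resolves the checkpoint through `lora_state_dict` above, then runs the compatibility check below before handing the weights to the transformer.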
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. 
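One detail worth spelling out from the loop above: the adapter rank is inferred purely from the `lora_B` weight shapes, since a `lora_B` matrix has shape `[out_features, rank]`. A toy illustration with made-up layer names and sizes:

```py
import torch

# lora_A: [rank, in_features], lora_B: [out_features, rank]
state_dict = {
    "transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 64),
    "transformer_blocks.0.attn1.to_q.lora_B.weight": torch.zeros(64, 4),
}

rank = {k: v.shape[1] for k, v in state_dict.items() if "lora_B" in k}
print(rank)  # {'transformer_blocks.0.attn1.to_q.lora_B.weight': 4}
```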
+ if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
+ + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components) + + class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 89d6a28b14dd9..d1c6721512faf 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -33,6 +33,7 @@ "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, "FluxTransformer2DModel": lambda model_cls, weights: weights, + "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 6f19e132eae51..821da6d032d59 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -19,7 +19,8 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 @@ -152,7 +153,7 @@ def forward( return hidden_states, encoder_hidden_states -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). 
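The next hunk threads an optional `attention_kwargs` dictionary through the transformer's `forward` so that a per-call LoRA `scale` can be applied with `scale_lora_layers` and removed again with `unscale_lora_layers`. Combined with the pipeline changes later in this diff, that surfaces to users roughly as follows (the adapter repository and prompt are placeholders):

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("my-user/my-cogvideox-lora")  # hypothetical trained adapter

# With the PEFT backend, the LoRA contribution is weighted by 0.7 for this call only.
video = pipe(
    prompt="a panda playing a guitar by a river, cinematic lighting",
    num_inference_steps=50,
    guidance_scale=6.0,
    attention_kwargs={"scale": 0.7},
).frames[0]
```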
@@ -411,8 +412,24 @@ def forward( timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_frames, channels, height, width = hidden_states.shape # 1. Time embedding @@ -481,6 +498,10 @@ def custom_forward(*inputs): output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 3af47c1774377..02497e77edb7b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -15,12 +15,13 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline @@ -136,7 +137,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXPipeline(DiffusionPipeline): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for text-to-video generation using CogVideoX. @@ -462,6 +463,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -487,6 +492,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -548,6 +554,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -592,6 +602,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Default call parameters @@ -673,6 +684,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 16686d1ab7ac7..649199829cf44 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -15,21 +15,19 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from PIL import Image from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import ( - logging, - replace_example_docstring, -) +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import CogVideoXPipelineOutput @@ -161,7 +159,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXVideoToVideoPipeline(DiffusionPipeline): +class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for video-to-video generation using CogVideoX. @@ -541,6 +539,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -567,6 +569,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -627,6 +630,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. 
The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -667,6 +674,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Default call parameters @@ -755,6 +763,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py new file mode 100644 index 0000000000000..17b1cc8e764a4 --- /dev/null +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -0,0 +1,182 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + floats_tensor, + is_peft_available, + require_peft_backend, + skip_mps, + torch_device, +) + + +if is_peft_available(): + pass + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +@require_peft_backend +class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = CogVideoXPipeline + scheduler_cls = CogVideoXDPMScheduler + scheduler_kwargs = {"timestep_spacing": "trailing"} + + transformer_kwargs = { + "num_attention_heads": 4, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "time_embed_dim": 2, + "text_embed_dim": 32, + "num_layers": 1, + "sample_width": 16, + "sample_height": 16, + "sample_frames": 9, + "patch_size": 2, + "temporal_compression_ratio": 4, + "max_text_seq_length": 16, + } + transformer_cls = CogVideoXTransformer3DModel + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + "up_block_types": ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 2, + "temporal_compression_ratio": 4, + } + vae_cls = AutoencoderKLCogVideoX + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + text_encoder_target_modules = ["q", "k", "v", "o"] + + @property + def output_shape(self): + return (1, 9, 16, 16, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + 
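The `num_latent_frames` comment above compresses a small derivation; with the test's own numbers (`num_frames=9`, `temporal_compression_ratio=4`) it works out as:

```py
num_frames = 9
temporal_compression_ratio = 4

# The CogVideoX VAE keeps the first frame and temporally compresses the remaining ones.
num_latent_frames = (num_frames - 1) // temporal_compression_ratio + 1
assert num_latent_frames == 3
```

Spatially, the test VAE has four `block_out_channels`, so the scale factor is `2 ** (4 - 1) = 8` and the 16x16 test resolution maps to the 2x2 latent `sizes` set just below.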
sizes = (2, 2) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "dance monkey", + "num_frames": num_frames, + "num_inference_steps": 4, + "guidance_scale": 6.0, + # Cannot reduce because convolution kernel becomes bigger than sample + "height": 16, + "width": 16, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @skip_mps + def test_lora_fuse_nan(self): + scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler] + for scheduler_cls in scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + + out = pipe( + "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np" + )[0] + + self.assertTrue(np.isnan(out).all()) + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3) + + @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 283b9f534766c..adf7cb24470f7 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -85,8 +85,11 @@ class PeftLoraLoaderMixinTests: unet_kwargs = None transformer_cls = None transformer_kwargs = None + vae_cls = AutoencoderKL vae_kwargs = None + text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + def get_dummy_components(self, scheduler_cls=None, use_dora=False): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") @@ -105,7 +108,7 @@ def 
get_dummy_components(self, scheduler_cls=None, use_dora=False): scheduler = scheduler_cls(**self.scheduler_kwargs) torch.manual_seed(0) - vae = AutoencoderKL(**self.vae_kwargs) + vae = self.vae_cls(**self.vae_kwargs) text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) @@ -121,7 +124,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): text_lora_config = LoraConfig( r=rank, lora_alpha=rank, - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + target_modules=self.text_encoder_target_modules, init_lora_weights=False, use_dora=use_dora, ) @@ -202,6 +205,9 @@ def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected """ + # TODO(aryan): Some of the assumptions made here in many different tests are incorrect for CogVideoX. + # For example, we need to test with CogVideoXDDIMScheduler and CogVideoDPMScheduler instead of DDIMScheduler + # and LCMScheduler, which are not supported by it. scheduler_classes = ( [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) @@ -212,7 +218,7 @@ def test_simple_inference(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs() - output_no_lora = pipe(**inputs).images + output_no_lora = pipe(**inputs)[0] self.assertTrue(output_no_lora.shape == self.output_shape) def test_simple_inference_with_text_lora(self): @@ -230,7 +236,7 @@ def test_simple_inference_with_text_lora(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) @@ -244,7 +250,7 @@ def test_simple_inference_with_text_lora(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) @@ -257,6 +263,13 @@ def test_simple_inference_with_text_lora_and_scale(self): scheduler_classes = ( [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) + call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() + for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: + if possible_attention_kwargs in call_signature_keys: + attention_kwargs_name = possible_attention_kwargs + break + assert attention_kwargs_name is not None + for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -264,7 +277,7 @@ def test_simple_inference_with_text_lora_and_scale(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) @@ -278,32 +291,22 @@ def test_simple_inference_with_text_lora_and_scale(self): 
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) - if self.unet_kwargs is not None: - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images - else: - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5} - ).images + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} + output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertTrue( not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), "Lora + scale should change the output", ) - if self.unet_kwargs is not None: - output_lora_0_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} - ).images - else: - output_lora_0_scale = pipe( - **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0} - ).images + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} + output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertTrue( np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), "Lora + 0 scale should lead to same result as no LoRA", @@ -324,7 +327,7 @@ def test_simple_inference_with_text_lora_fused(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) @@ -347,7 +350,7 @@ def test_simple_inference_with_text_lora_fused(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images + ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertFalse( np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) @@ -367,11 +370,14 @@ def test_simple_inference_with_text_lora_unloaded(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules @@ -394,7 +400,7 @@ def test_simple_inference_with_text_lora_unloaded(self): "Lora not correctly unloaded in text encoder 2", ) - ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images + ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( 
np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output", @@ -414,11 +420,14 @@ def test_simple_inference_with_text_lora_save_load(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: @@ -427,7 +436,7 @@ def test_simple_inference_with_text_lora_save_load(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) @@ -461,7 +470,7 @@ def test_simple_inference_with_text_lora_save_load(self): pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: @@ -500,7 +509,7 @@ def test_simple_inference_with_partial_text_lora(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) @@ -527,7 +536,7 @@ def test_simple_inference_with_partial_text_lora(self): } ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) @@ -536,7 +545,7 @@ def test_simple_inference_with_partial_text_lora(self): pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) - output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), "Removing adapters should change the output", @@ -556,7 +565,7 @@ def test_simple_inference_save_pretrained(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) @@ -569,7 +578,7 
@@ def test_simple_inference_save_pretrained(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: pipe.save_pretrained(tmpdirname) @@ -589,7 +598,7 @@ def test_simple_inference_save_pretrained(self): "Lora not correctly set in text encoder 2", ) - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images + images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), @@ -603,9 +612,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): scheduler_classes = ( [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) - scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] - ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -613,16 +619,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config) else: pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet") @@ -633,10 +643,14 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] with tempfile.TemporaryDirectory() as tmpdirname: - text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) + text_encoder_state_dict = ( + get_peft_model_state_dict(pipe.text_encoder) + if "text_encoder" in self.pipeline_class._lora_loadable_modules + else None + ) if self.unet_kwargs is not None: denoiser_state_dict = get_peft_model_state_dict(pipe.unet) @@ -645,10 +659,12 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): saving_kwargs = { "save_directory": tmpdirname, - "text_encoder_lora_layers": text_encoder_state_dict, "safe_serialization": False, } + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict}) + if self.unet_kwargs is not None: saving_kwargs.update({"unet_lora_layers": denoiser_state_dict}) else: 
@@ -666,8 +682,13 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") @@ -690,6 +711,13 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): scheduler_classes = ( [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) + call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() + for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]: + if possible_attention_kwargs in call_signature_keys: + attention_kwargs_name = possible_attention_kwargs + break + assert attention_kwargs_name is not None + for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -697,15 +725,20 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(output_no_lora.shape == self.output_shape) - pipe.text_encoder.add_adapter(text_lora_config) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config) else: pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") @@ -716,41 +749,32 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) - if self.unet_kwargs is not None: - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images - else: - output_lora_scale = pipe( - **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5} - ).images + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} + output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + 
             self.assertTrue(
                 not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                 "Lora + scale should change the output",
             )

-            if self.unet_kwargs is not None:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
-                ).images
-            else:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
-                ).images
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
             self.assertTrue(
                 np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                 "Lora + 0 scale should lead to same result as no LoRA",
             )

-            self.assertTrue(
-                pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
-                "The scaling parameter has not been correctly restored!",
-            )
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(
+                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+                    "The scaling parameter has not been correctly restored!",
+                )

     def test_simple_inference_with_text_lora_denoiser_fused(self):
         """
@@ -767,16 +791,20 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(output_no_lora.shape == self.output_shape)

-            pipe.text_encoder.add_adapter(text_lora_config)
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config)
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config)

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -787,9 +815,14 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

-            pipe.fuse_lora()
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+
             # Fusing should still keep the LoRA layers
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -799,9 +832,9 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

-            ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertFalse(
-                np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
             )

     def test_simple_inference_with_text_denoiser_lora_unloaded(self):
@@ -819,15 +852,19 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(output_no_lora.shape == self.output_shape)

-            pipe.text_encoder.add_adapter(text_lora_config)
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config)
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config)

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -855,13 +892,15 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
                    "Lora not correctly unloaded in text encoder 2",
                )

-            ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                 "Fused lora should change the output",
             )

-    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+    def test_simple_inference_with_text_denoiser_lora_unfused(
+        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
         and makes sure it works as expected
@@ -876,13 +915,17 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            pipe.text_encoder.add_adapter(text_lora_config)
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config)
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config)

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -893,15 +936,16 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

-            pipe.fuse_lora()
-
-            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.unfuse_lora()
+            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
             # unloading should remove the LoRA layers
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
@@ -913,8 +957,8 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
             # Fuse and unfuse should lead to the same results
             self.assertTrue(
-                np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
-                "Fused lora should change the output",
+                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
+                "Fused lora should not change the output",
             )

     def test_simple_inference_with_text_denoiser_multi_adapter(self):
@@ -932,10 +976,14 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
@@ -946,7 +994,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -959,15 +1006,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
                )

             pipe.set_adapters("adapter-1")
-
-            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters("adapter-2")
-            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters(["adapter-1", "adapter-2"])
-
-            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

             # Fuse and unfuse should lead to the same results
             self.assertFalse(
@@ -986,8 +1031,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
                )

             pipe.disable_lora()
-
-            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -999,7 +1043,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
         Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set differnt weights for different blocks (i.e. block lora)
        """
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "CogVideoXPipeline"]:
             return

         scheduler_classes = (
@@ -1012,7 +1056,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             if self.unet_kwargs is not None:
@@ -1033,11 +1077,11 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
             pipe.set_adapters("adapter-1", weights_1)
-            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             weights_2 = {"unet": {"up": 5}}
             pipe.set_adapters("adapter-1", weights_2)
-            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertFalse(
                 np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
@@ -1053,7 +1097,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
                )

             pipe.disable_lora()
-            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1078,10 +1122,14 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1092,7 +1140,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1106,16 +1153,15 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
             scales_2 = {"unet": {"down": 5, "mid": 5}}
-            pipe.set_adapters("adapter-1", scales_1)
-            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            pipe.set_adapters("adapter-1", scales_1)
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters("adapter-2", scales_2)
-            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
-
-            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

             # Fuse and unfuse should lead to the same results
             self.assertFalse(
@@ -1134,8 +1180,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
                )

             pipe.disable_lora()
-
-            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1148,7 +1193,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
-        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
             return

         def updown_options(blocks_with_tf, layers_per_block, value):
@@ -1253,21 +1298,25 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1281,15 +1330,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
                )

             pipe.set_adapters("adapter-1")
-
-            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters("adapter-2")
-            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters(["adapter-1", "adapter-2"])
-
-            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertFalse(
                 np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
@@ -1307,7 +1354,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
                )

             pipe.delete_adapters("adapter-1")
-            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
@@ -1315,15 +1362,16 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             )

             pipe.delete_adapters("adapter-2")
-            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                 "output with no lora and output with lora disabled should give same results",
             )

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1337,7 +1385,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_adapters(["adapter-1", "adapter-2"])
             pipe.delete_adapters(["adapter-1", "adapter-2"])
-            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
@@ -1359,10 +1407,14 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
@@ -1373,7 +1425,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1387,15 +1438,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
                )

             pipe.set_adapters("adapter-1")
-
-            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters("adapter-2")
-            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters(["adapter-1", "adapter-2"])
-
-            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

             # Fuse and unfuse should lead to the same results
             self.assertFalse(
@@ -1414,7 +1463,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
                )

             pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
-            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertFalse(
                 np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
@@ -1423,7 +1472,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.disable_lora()

-            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1442,14 +1491,17 @@ def test_lora_fuse_nan(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )

             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1464,12 +1516,12 @@ def test_lora_fuse_nan(self):
             # with `safe_fusing=True` we should see an Error
             with self.assertRaises(ValueError):
-                pipe.fuse_lora(safe_fusing=True)
+                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

             # without we should not see an error, but every image will be black
-            pipe.fuse_lora(safe_fusing=False)
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)

-            out = pipe("test", num_inference_steps=2, output_type="np").images
+            out = pipe("test", num_inference_steps=2, output_type="np")[0]

             self.assertTrue(np.isnan(out).all())
@@ -1523,59 +1575,80 @@ def test_get_list_adapters(self):
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+        # 1.
+        dicts_to_be_checked = {}
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
+
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
         else:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

-        adapter_names = pipe.get_list_adapters()
-        dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
         if self.unet_kwargs is not None:
             dicts_to_be_checked.update({"unet": ["adapter-1"]})
         else:
             dicts_to_be_checked.update({"transformer": ["adapter-1"]})

-        self.assertDictEqual(adapter_names, dicts_to_be_checked)
-        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+
+        # 2.
+        dicts_to_be_checked = {}
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
         else:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-        adapter_names = pipe.get_list_adapters()
-        dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
         if self.unet_kwargs is not None:
             dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
         else:
             dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

-        self.assertDictEqual(adapter_names, dicts_to_be_checked)
+        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+
+        # 3.
         pipe.set_adapters(["adapter-1", "adapter-2"])
-        dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+
+        dicts_to_be_checked = {}
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+
         if self.unet_kwargs is not None:
             dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
         else:
             dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+
         self.assertDictEqual(
             pipe.get_list_adapters(),
             dicts_to_be_checked,
         )

+        # 4.
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
         else:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")

-        dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+        dicts_to_be_checked = {}
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+
         if self.unet_kwargs is not None:
             dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
         else:
             dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
+
         self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

     @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(
+        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
@@ -1590,23 +1663,29 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(output_no_lora.shape == self.output_shape)

-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

             # Attach a second adapter
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
             if self.unet_kwargs is not None:
                 pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
             else:
                 pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1621,28 +1700,30 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
             # set them to multi-adapter inference mode
             pipe.set_adapters(["adapter-1", "adapter-2"])
-            ouputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.set_adapters(["adapter-1"])
-            ouputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

-            pipe.fuse_lora(adapter_names=["adapter-1"])
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

             # Fusing should still keep the LoRA layers so outpout should remain the same
-            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
-                np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
+                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                 "Fused lora should not change the output",
             )

-            pipe.unfuse_lora()
-            pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
+            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            pipe.fuse_lora(
+                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
+            )

             # Fusing should still keep the LoRA layers
-            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
+                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                 "Fused lora should not change the output",
             )
@@ -1660,7 +1741,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)

             pipe.text_encoder.add_adapter(text_lora_config)
@@ -1681,7 +1762,7 @@ def test_simple_inference_with_dora(self):
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

-            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertFalse(
                 np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
@@ -1727,10 +1808,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

             # Just makes sure it works..
-            _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

     def test_modify_padding_mode(self):
-        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
             return

         def set_pad_mode(network, mode="circular"):
@@ -1751,4 +1832,4 @@ def set_pad_mode(network, mode="circular"):
             set_pad_mode(pipe.unet, _pad_mode)
             _, _, inputs = self.get_dummy_inputs()

-            _ = pipe(**inputs).images
+            _ = pipe(**inputs)[0]