sayakpaul/stable-diffusion-keras-ft

Fine-tuning Stable Diffusion using Keras

This repository provides code for fine-tuning Stable Diffusion in Keras. It is adapted from this script by Hugging Face. The pre-trained model used for fine-tuning comes from KerasCV. To learn more about the original model, check out this documentation.

The code provided in this repository is for research purposes only. Please check out this section to learn more about the potential use cases and limitations.

By loading this model you accept the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE.

If you're just looking for the accompanying resources of this repository, here are the links:

This repository has a sister repository (keras-sd-serving) that covers various deployment patterns for Stable Diffusion.

Update January 13 2023: This project secured 2nd place at the first-ever Keras Community Prize Competition organized by Google.

Dataset

Following the original script from Hugging Face, this repository also uses the Pokemon dataset, regenerated to work better with tf.data. The regenerated version of the dataset is hosted here. Check out that link for more details.

Training

Fine-tuning code is provided in finetune.py. Before running training, ensure you have the dependencies (refer to requirements.txt) installed.

You can launch training with the default arguments by running python finetune.py. Run python finetune.py -h to see the supported command-line arguments. You can enable mixed-precision training by passing the --mp flag.

During training, a diffusion model checkpoint is saved at the end of an epoch only if that epoch's loss is lower than the previous one.
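The checkpointing rule can be sketched in plain Python. This is a minimal illustration, not the code in finetune.py: it assumes "the previous one" means the lowest loss seen so far, and epochs_that_checkpoint is a hypothetical helper name.

```python
import math


def epochs_that_checkpoint(epoch_losses):
    """Return the (0-based) epochs at which a checkpoint would be written.

    Assumption: a checkpoint is written only when the epoch loss is
    strictly lower than the best loss seen so far.
    """
    best_loss = math.inf
    saved_epochs = []
    for epoch, loss in enumerate(epoch_losses):
        if loss < best_loss:
            best_loss = loss
            saved_epochs.append(epoch)
    return saved_epochs


# Hypothetical per-epoch losses, for illustration only.
print(epochs_that_checkpoint([0.30, 0.25, 0.27, 0.24]))  # [0, 1, 3]
```

Under this reading, an epoch whose loss regresses (the third one here) leaves the last saved checkpoint untouched.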

To avoid out-of-memory (OOM) errors and to train faster, we recommend using at least a V100 GPU. We used an A100.

Some important details to note:

  • Distributed training is not yet supported. Gradient accumulation and gradient checkpointing are also not supported.
  • Only the diffusion model is fine-tuned. The VAE and the text encoder are kept frozen.

Training details:

We fine-tuned the model at two resolutions: 256x256 and 512x512. Between the two runs, we varied only the batch size and the number of epochs. Since we didn't use gradient accumulation, we used this code snippet to derive the number of epochs.

  • 256x256: python finetune.py --batch_size 4 --num_epochs 577
  • 512x512: python finetune.py --img_height 512 --img_width 512 --batch_size 1 --num_epochs 72 --mp

For 256x256 resolution, we intentionally reduced the number of epochs to save compute time.
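The snippet referred to above is not reproduced in this text. As a rough sketch of the kind of arithmetic involved, the following converts an optimizer-step budget into a number of epochs. The function name and all input values (a 15000-step budget, 833 training examples, 4 accumulation steps) are assumptions for illustration, not values confirmed by this repository.

```python
import math


def derive_num_epochs(max_train_steps, num_examples, batch_size,
                      gradient_accumulation_steps=1):
    """Convert a budget of optimizer updates into whole epochs."""
    # Batches the data loader yields in one pass over the dataset.
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # Optimizer updates per epoch when gradients are accumulated.
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    return math.ceil(max_train_steps / updates_per_epoch)


# Assumed reference recipe: 15000 updates, 833 examples, batch size 1,
# 4 gradient accumulation steps.
print(derive_num_epochs(15000, 833, batch_size=1, gradient_accumulation_steps=4))  # 72
```

With these assumed inputs, the result coincides with the --num_epochs 72 used for 512x512, but that match does not confirm the inputs. Without gradient accumulation, training for the epoch count of an accumulated recipe keeps the number of examples seen roughly the same while performing more, smaller updates.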

Fine-tuned weights:

You can find the fine-tuned diffusion model weights here.

Training with custom data

The default Pokemon dataset used in this repository comes with the following structure:

pokemon_dataset/
    data.csv
    image_24.png   
    image_3.png    
    image_550.png  
    image_700.png
    ...

data.csv looks like so:

As long as your custom dataset follows this structure, you don't need to change anything in the current codebase except for the dataset_archive.

In case your dataset has multiple captions per image, you can randomly select one from the pool of captions per image during training.
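One way to do the per-image caption sampling is sketched below in plain Python. The mapping layout, the helper name sample_caption, and the caption strings are hypothetical; the repository's data pipeline is built on tf.data and may need a different integration point.

```python
import random


def sample_caption(captions_by_image, image_name, rng=random):
    """Pick one caption uniformly at random from an image's caption pool."""
    return rng.choice(captions_by_image[image_name])


# Hypothetical caption pools keyed by image file name.
captions_by_image = {
    "image_24.png": [
        "a drawing of a green creature",
        "a cartoon character with leaves",
    ],
    "image_3.png": ["a blue turtle-like creature"],
}

# Re-sampling every epoch lets the model see different captions over time.
rng = random.Random(0)
for epoch in range(2):
    for name in captions_by_image:
        print(epoch, name, sample_caption(captions_by_image, name, rng))
```

Passing a seeded random.Random instance keeps the sampling reproducible across runs.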

Depending on the dataset, you might have to tune the hyperparameters.

Inference

import keras_cv
import matplotlib.pyplot as plt
from tensorflow import keras

IMG_HEIGHT = IMG_WIDTH = 512


def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.title(title)
        plt.imshow(images[i])
        plt.axis("off")


# We just have to load the fine-tuned weights into the diffusion model.
weights_path = keras.utils.get_file(
    origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)
pokemon_model = keras_cv.models.StableDiffusion(
    img_height=IMG_HEIGHT, img_width=IMG_WIDTH
)
pokemon_model.diffusion_model.load_weights(weights_path)

# Generate images.
generated_images = pokemon_model.text_to_image("Yoda", batch_size=3)
plot_images(generated_images, "Fine-tuned on the Pokemon dataset")

You can substitute your own weights_path (the weights must be compatible with the diffusion_model) and reuse the code snippet.

Check out this Colab Notebook to play with the inference code.

Results

Initially, we fine-tuned the model on a resolution of 256x256. Here are some results along with comparisons to the results of the original model.

Prompts for the generated images:

  • Yoda
  • robotic cat with wings
  • Hello Kitty

Weights

We can see that the fine-tuned model has more stable outputs than the original model. Although the results could be improved aesthetically, the effects of fine-tuning are visible. For the 256x256 resolution, we used the same hyperparameters as Hugging Face's script (apart from the number of epochs and the batch size). With better hyperparameters, the results will likely improve.

For the 512x512 resolution, we observed something similar. We then experimented with the unconditional_guidance_scale parameter and noticed that setting it to 40 (while keeping the other arguments fixed) gave better results.

Prompts for the generated images:

  • Yoda
  • robotic cat with wings
  • Hello Kitty

Weights
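To compare guidance scales yourself, a small helper along these lines may be useful. It is a sketch: sweep_guidance_scales is a hypothetical name, and the baseline of 7.5 is an assumed value to compare against rather than a setting taken from this repository. The model argument is expected to behave like the pokemon_model built in the Inference section.

```python
def sweep_guidance_scales(model, prompt, scales=(7.5, 40.0), batch_size=3):
    """Generate one batch of images per guidance scale.

    `model` must expose `text_to_image(prompt, batch_size=...,
    unconditional_guidance_scale=...)`, as KerasCV's StableDiffusion does.
    Returns a dict mapping each scale to its generated batch.
    """
    return {
        scale: model.text_to_image(
            prompt,
            batch_size=batch_size,
            unconditional_guidance_scale=scale,
        )
        for scale in scales
    }


# Example usage with the model from the Inference section:
# results = sweep_guidance_scales(pokemon_model, "Yoda")
# plot_images(results[40.0], "unconditional_guidance_scale=40")
```

Keeping the prompt and batch size fixed while varying only the scale mirrors the comparison described above.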

Note: Fine-tuning at the 512x512 resolution is still in progress as of this writing. Without distributed training and gradient accumulation, a single epoch takes a long time to complete. The above results are from the checkpoint saved after the 60th epoch.

With a similar recipe (but trained for more optimization steps), Lambda Labs demonstrates amazing results.

Acknowledgements

  • Thanks to Hugging Face for providing the fine-tuning script. It's very readable and easy to understand.
  • Thanks to the ML Developer Programs' team at Google for providing GCP credits.
