Commit 4d14777

Utilize batches in Stable Diffusion (huggingface#2071)
* Utilize batches in Stable Diffusion that were already there but went unused.

Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
NorilskMajor and LaurentMazare authored Apr 16, 2024
1 parent f135b79 commit 4d14777
Showing 2 changed files with 60 additions and 17 deletions.
3 changes: 2 additions & 1 deletion candle-examples/examples/stable-diffusion/README.md
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.
-- `--num-samples`: the number of samples to generate.
+- `--num-samples`: the number of samples to generate iteratively.
+- `--bsize`: the number of samples to generate simultaneously.
 - `--final-image`: the filename for the generated image(s).
 
 ### Using flash-attention
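Taken together, the two new flags compose multiplicatively: each of the `--num-samples` iterations now decodes `--bsize` latents in a single pass, so a run along the lines of `cargo run --example stable-diffusion --release -- --prompt "a photo of a cat" --num-samples 3 --bsize 2` (invocation illustrative; feature flags omitted) writes six images in total.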
74 changes: 58 additions & 16 deletions candle-examples/examples/stable-diffusion/main.rs
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;
 
 #[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
     #[arg(long)]
     n_steps: Option<usize>,
 
-    /// The number of samples to generate.
+    /// The number of samples to generate iteratively.
     #[arg(long, default_value_t = 1)]
-    num_samples: i64,
+    num_samples: usize,
+
+    /// The number of samples to generate simultaneously.
+    #[arg(long, default_value_t = 1)]
+    bsize: usize,
 
     /// The name of the final image to generate.
     #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {
 
 fn output_filename(
     basename: &str,
-    sample_idx: i64,
-    num_samples: i64,
+    sample_idx: usize,
+    num_samples: usize,
     timestep_idx: Option<usize>,
 ) -> String {
     let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
     }
 }
 
+#[allow(clippy::too_many_arguments)]
+fn save_image(
+    vae: &AutoEncoderKL,
+    latents: &Tensor,
+    vae_scale: f64,
+    bsize: usize,
+    idx: usize,
+    final_image: &str,
+    num_samples: usize,
+    timestep_ids: Option<usize>,
+) -> Result<()> {
+    let images = vae.decode(&(latents / vae_scale)?)?;
+    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
+    for batch in 0..bsize {
+        let image = images.i(batch)?;
+        let image_filename = output_filename(
+            final_image,
+            (bsize * idx) + batch + 1,
+            batch + num_samples,
+            timestep_ids,
+        );
+        candle_examples::save_image(&image, image_filename)?;
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 fn text_embeddings(
     prompt: &str,
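The index arithmetic in the new `save_image` is the heart of the batching change: iteration `idx` now accounts for `bsize` images, numbered `(bsize * idx) + batch + 1` across the whole run. A minimal standalone sketch of that numbering (values illustrative; the actual file naming comes from `output_filename`, whose body is only partially shown above):

// Sketch only: reproduces the 1-based global sample numbering that
// `save_image` passes to `output_filename`.
fn main() {
    let (bsize, num_samples) = (2usize, 3usize);
    for idx in 0..num_samples {
        for batch in 0..bsize {
            let sample_idx = (bsize * idx) + batch + 1;
            println!("iteration {idx}, batch slot {batch} -> sample {sample_idx}");
        }
    }
}

With `bsize = 2` and `num_samples = 3` this yields samples 1 through 6, one contiguous block of two per iteration.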
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
+        bsize,
         sd_version,
         clip_weights,
         vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
         .collect::<Result<Vec<_>>>()?;
 
     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
     println!("{text_embeddings:?}");
 
     println!("Building the autoencoder.");
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
     } else {
         0
     };
-    let bsize = 1;
 
     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
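Note the removal above: `bsize` was previously hard-coded to 1 at this point, which is why the batch dimension existed but went unused; it now arrives from the CLI via the `Args` destructuring earlier in `run`.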
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
 
             if args.intermediary_images {
-                let image = vae.decode(&(&latents / vae_scale)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
+                save_image(
+                    &vae,
+                    &latents,
+                    vae_scale,
+                    bsize,
+                    idx,
+                    &final_image,
+                    num_samples,
+                    Some(timestep_index + 1),
+                )?;
             }
         }
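A side effect of routing the intermediary images through the shared `save_image`: they are now clamped to `[0, 1]` before the scale to 255, a step the removed inline code applied only to the final image.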

@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / vae_scale)?)?;
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
-        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
-        candle_examples::save_image(&image, image_filename)?
+        save_image(
+            &vae,
+            &latents,
+            vae_scale,
+            bsize,
+            idx,
+            &final_image,
+            num_samples,
+            None,
+        )?;
     }
     Ok(())
 }