Skip to content

Commit

Permalink
Handle steps and guidance for each unet stage
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Apr 30, 2024
1 parent 3c5474d commit 2ab0fce
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 40 deletions.
3 changes: 3 additions & 0 deletions OnnxStack.StableDiffusion/Config/SchedulerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ public record SchedulerOptions

public float ConditioningScale { get; set; } = 0.7f;

public int InferenceSteps2 { get; set; } = 10;
public float GuidanceScale2 { get; set; } = 0;

public bool IsKarrasScheduler
{
get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
Expand All @@ -17,6 +18,9 @@ namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
{
public abstract class StableCascadeDiffuser : DiffuserBase
{
private readonly float _latentDimScale;
private readonly float _resolutionMultiple;
private readonly int _clipImageChannels;
private readonly UNetConditionModel _decoderUnet;

/// <summary>
Expand All @@ -32,6 +36,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
: base(priorUnet, decoderVqgan, imageEncoder, memoryMode, logger)
{
_decoderUnet = decoderUnet;
_latentDimScale = 10.67f;
_resolutionMultiple = 42.67f;
_clipImageChannels = 768;
}

/// <summary>
Expand All @@ -40,6 +47,32 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableCascade;


/// <summary>
/// Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
/// height=24 and width = 24, the VQ latent shape needs to be height=int (24*10.67)=256 and
/// width = int(24 * 10.67) = 256 in order to match the training conditions.
/// </summary>
protected float LatentDimScale => _latentDimScale;


/// <summary>
/// Default resolution for multiple images generated
/// </summary>
protected float ResolutionMultiple => _resolutionMultiple;


/// <summary>
/// Prepares the decoder latents.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <param name="priorLatents">The prior latents.</param>
/// <returns></returns>
protected abstract Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents);


/// <summary>
/// Runs the scheduler steps.
/// </summary>
Expand All @@ -52,27 +85,55 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
/// <returns></returns>
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
var decodeSchedulerOptions = schedulerOptions with
{
InferenceSteps = schedulerOptions.InferenceSteps2,
GuidanceScale = schedulerOptions.GuidanceScale2
};

var priorPromptEmbeddings = promptEmbeddings;
var decoderPromptEmbeddings = promptEmbeddings;
var priorPerformGuidance = schedulerOptions.GuidanceScale > 0;
var decoderPerformGuidance = decodeSchedulerOptions.GuidanceScale > 0;
if (performGuidance)
{
if (!priorPerformGuidance)
priorPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
if (!decoderPerformGuidance)
decoderPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
}

// Prior Unet
var latentsPrior = await DiffusePriorAsync(schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
var priorLatents = await DiffusePriorAsync(promptOptions, schedulerOptions, priorPromptEmbeddings, priorPerformGuidance, progressCallback, cancellationToken);

// Decoder Unet
var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 };
var latents = await DiffuseDecodeAsync(latentsPrior, schedulerOptionsDecoder, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
var decoderLatents = await DiffuseDecodeAsync(promptOptions, priorLatents, decodeSchedulerOptions, decoderPromptEmbeddings, decoderPerformGuidance, progressCallback, cancellationToken);

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
return await DecodeLatentsAsync(promptOptions, schedulerOptions, decoderLatents);
}


protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)

/// <summary>
/// Run the Prior UNET diffusion
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample
var latents = scheduler.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, scheduler.InitNoiseSigma);
var latents = await PrepareLatentsAsync(prompt, schedulerOptions, scheduler, timesteps);

// Get Model metadata
var metadata = await _unet.GetMetadataAsync();
Expand All @@ -89,18 +150,15 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, 768 });

var outputChannels = performGuidance ? 2 : 1;
var outputDimension = inputTensor.Dimensions.ToArray();
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels });
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
inferenceParameters.AddInputTensor(imageEmbeds);
inferenceParameters.AddOutputBuffer(outputDimension);
inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);

var results = await _unet.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
Expand Down Expand Up @@ -129,23 +187,33 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
}


protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> latentsPrior, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
/// <summary>
/// Run the Decoder UNET diffusion
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="priorLatents">The prior latents.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
protected async Task<DenseTensor<float>> DiffuseDecodeAsync(PromptOptions prompt, DenseTensor<float> priorLatents, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample
var latents = scheduler.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, scheduler.InitNoiseSigma);
var latents = await PrepareDecoderLatentsAsync(prompt, schedulerOptions, scheduler, timesteps, priorLatents);

// Get Model metadata
var metadata = await _decoderUnet.GetMetadataAsync();

var effnet = performGuidance
? latentsPrior
: latentsPrior.Concatenate(new DenseTensor<float>(latentsPrior.Dimensions));

var effnet = !performGuidance
? priorLatents
: priorLatents.Repeat(2);

// Loop though the timesteps
var step = 0;
Expand All @@ -159,18 +227,15 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);

var outputChannels = performGuidance ? 2 : 1;
var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
inferenceParameters.AddInputTensor(effnet);
inferenceParameters.AddOutputBuffer();
inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);

var results = _decoderUnet.RunInference(inferenceParameters);
var results = await _decoderUnet.RunInferenceAsync(inferenceParameters);
using (var result = results.First())
{
var noisePred = result.ToDenseTensor();
Expand All @@ -197,6 +262,13 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
}


/// <summary>
/// Decodes the latents.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="latents">The latents.</param>
/// <returns></returns>
protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
{
latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor);
Expand Down Expand Up @@ -239,6 +311,19 @@ private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int
}


/// <summary>
/// Splits the prompt embeddings, Removes unconditional embeddings
/// </summary>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <returns></returns>
private PromptEmbeddingsResult SplitPromptEmbeddings(PromptEmbeddingsResult promptEmbeddings)
{
return promptEmbeddings.PooledPromptEmbeds is null
? new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds.SplitBatch().Last())
: new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds.SplitBatch().Last(), promptEmbeddings.PooledPromptEmbeds.SplitBatch().Last());
}


/// <summary>
/// Gets the scheduler.
/// </summary>
Expand Down
28 changes: 20 additions & 8 deletions OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

Expand Down Expand Up @@ -47,16 +48,27 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
}


/// <summary>
/// Prepares the latents for inference.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
var latents = scheduler.CreateRandomSample(new[]
{
1, 16,
(int)Math.Ceiling(options.Height / ResolutionMultiple),
(int)Math.Ceiling(options.Width / ResolutionMultiple)
}, scheduler.InitNoiseSigma);
return Task.FromResult(latents);
}


protected override Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
{
var latents = scheduler.CreateRandomSample(new[]
{
1, 4,
(int)(priorLatents.Dimensions[2] * LatentDimScale),
(int)(priorLatents.Dimensions[3] * LatentDimScale)
}, scheduler.InitNoiseSigma);
return Task.FromResult(latents);
}
}
}
11 changes: 11 additions & 0 deletions OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
}


/// <summary>
/// Check if we should run guidance.
/// </summary>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <returns></returns>
protected override bool ShouldPerformGuidance(SchedulerOptions schedulerOptions)
{
return schedulerOptions.GuidanceScale > 1f || schedulerOptions.GuidanceScale2 > 1f;
}


/// <summary>
/// Creates the embeds using Tokenizer2 and TextEncoder2
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ internal class DDPMWuerstchenScheduler : SchedulerBase
private float _s;
private float _scaler;
private float _initAlphaCumprod;
private float _timestepRatio = 1000f;


/// <summary>
Expand Down Expand Up @@ -47,14 +48,12 @@ protected override void Initialize()
/// <returns></returns>
protected override int[] SetTimesteps()
{
// Create timesteps based on the specified strategy
var timesteps = ArrayHelpers.Linspace(0, 1000, Options.InferenceSteps + 1);
var x = timesteps
var timesteps = ArrayHelpers.Linspace(0, _timestepRatio, Options.InferenceSteps + 1);
return timesteps
.Skip(1)
.Select(x => (int)x)
.OrderByDescending(x => x)
.ToArray();
return x;
}


Expand Down Expand Up @@ -82,8 +81,8 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
/// <exception cref="NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
{
var currentTimestep = timestep / 1000f;
var previousTimestep = GetPreviousTimestep(timestep) / 1000f;
var currentTimestep = timestep / _timestepRatio;
var previousTimestep = GetPreviousTimestep(timestep) / _timestepRatio;

var alpha_cumprod = GetAlphaCumprod(currentTimestep);
var alpha_cumprod_prev = GetAlphaCumprod(previousTimestep);
Expand All @@ -108,10 +107,10 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
/// <returns></returns>
public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
{
float timestep = timesteps[0] / 1000f;
float alphaProd = GetAlphaCumprod(timestep);
float sqrtAlpha = MathF.Sqrt(alphaProd);
float sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);
var timestep = timesteps[0] / _timestepRatio;
var alphaProd = GetAlphaCumprod(timestep);
var sqrtAlpha = MathF.Sqrt(alphaProd);
var sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);

return noise
.MultiplyTensorByFloat(sqrtOneMinusAlpha)
Expand Down

0 comments on commit 2ab0fce

Please sign in to comment.