This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2ab0fce

Handle steps and guidance for each unet stage
1 parent 3c5474d commit 2ab0fce

File tree: 5 files changed, +150 -40 lines

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@ public record SchedulerOptions
 
         public float ConditioningScale { get; set; } = 0.7f;
 
+        public int InferenceSteps2 { get; set; } = 10;
+        public float GuidanceScale2 { get; set; } = 0;
+
         public bool IsKarrasScheduler
         {
             get
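
The two new properties let the caller tune the prior UNet and the decoder UNet independently instead of the decoder being hard-coded to 10 steps with no guidance. A minimal configuration sketch, assuming the usual OnnxStack SchedulerOptions usage pattern; the values are illustrative, not defaults introduced by this commit beyond what the diff shows:

    // InferenceSteps/GuidanceScale drive the prior stage; the new
    // InferenceSteps2/GuidanceScale2 drive the decoder stage.
    var schedulerOptions = new SchedulerOptions
    {
        InferenceSteps = 30,    // prior UNet steps
        GuidanceScale = 4f,     // prior UNet guidance
        InferenceSteps2 = 12,   // decoder UNet steps (new)
        GuidanceScale2 = 0f     // decoder UNet guidance (new, 0 = unguided)
    };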

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 107 additions & 22 deletions
@@ -8,6 +8,7 @@
 using OnnxStack.StableDiffusion.Models;
 using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
 using System;
+using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
 using System.Threading;
@@ -17,6 +18,9 @@ namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
 {
     public abstract class StableCascadeDiffuser : DiffuserBase
     {
+        private readonly float _latentDimScale;
+        private readonly float _resolutionMultiple;
+        private readonly int _clipImageChannels;
         private readonly UNetConditionModel _decoderUnet;
 
         /// <summary>
@@ -32,6 +36,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
             : base(priorUnet, decoderVqgan, imageEncoder, memoryMode, logger)
         {
             _decoderUnet = decoderUnet;
+            _latentDimScale = 10.67f;
+            _resolutionMultiple = 42.67f;
+            _clipImageChannels = 768;
         }
 
         /// <summary>
@@ -40,6 +47,32 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
         public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableCascade;
 
 
+        /// <summary>
+        /// Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
+        /// height=24 and width = 24, the VQ latent shape needs to be height=int (24*10.67)=256 and
+        /// width = int(24 * 10.67) = 256 in order to match the training conditions.
+        /// </summary>
+        protected float LatentDimScale => _latentDimScale;
+
+
+        /// <summary>
+        /// Default resolution for multiple images generated
+        /// </summary>
+        protected float ResolutionMultiple => _resolutionMultiple;
+
+
+        /// <summary>
+        /// Prepares the decoder latents.
+        /// </summary>
+        /// <param name="prompt">The prompt.</param>
+        /// <param name="options">The options.</param>
+        /// <param name="scheduler">The scheduler.</param>
+        /// <param name="timesteps">The timesteps.</param>
+        /// <param name="priorLatents">The prior latents.</param>
+        /// <returns></returns>
+        protected abstract Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents);
+
+
         /// <summary>
         /// Runs the scheduler steps.
         /// </summary>
@@ -52,27 +85,55 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
         /// <returns></returns>
         public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
         {
+            var decodeSchedulerOptions = schedulerOptions with
+            {
+                InferenceSteps = schedulerOptions.InferenceSteps2,
+                GuidanceScale = schedulerOptions.GuidanceScale2
+            };
+
+            var priorPromptEmbeddings = promptEmbeddings;
+            var decoderPromptEmbeddings = promptEmbeddings;
+            var priorPerformGuidance = schedulerOptions.GuidanceScale > 0;
+            var decoderPerformGuidance = decodeSchedulerOptions.GuidanceScale > 0;
+            if (performGuidance)
+            {
+                if (!priorPerformGuidance)
+                    priorPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
+                if (!decoderPerformGuidance)
+                    decoderPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
+            }
+
             // Prior Unet
-            var latentsPrior = await DiffusePriorAsync(schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
+            var priorLatents = await DiffusePriorAsync(promptOptions, schedulerOptions, priorPromptEmbeddings, priorPerformGuidance, progressCallback, cancellationToken);
 
             // Decoder Unet
-            var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 };
-            var latents = await DiffuseDecodeAsync(latentsPrior, schedulerOptionsDecoder, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
+            var decoderLatents = await DiffuseDecodeAsync(promptOptions, priorLatents, decodeSchedulerOptions, decoderPromptEmbeddings, decoderPerformGuidance, progressCallback, cancellationToken);
 
             // Decode Latents
-            return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
+            return await DecodeLatentsAsync(promptOptions, schedulerOptions, decoderLatents);
         }
 
 
-        protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+
+        /// <summary>
+        /// Run the Prior UNET diffusion
+        /// </summary>
+        /// <param name="prompt">The prompt.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
         {
             using (var scheduler = GetScheduler(schedulerOptions))
             {
                 // Get timesteps
                 var timesteps = GetTimesteps(schedulerOptions, scheduler);
 
                 // Create latent sample
-                var latents = scheduler.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, scheduler.InitNoiseSigma);
+                var latents = await PrepareLatentsAsync(prompt, schedulerOptions, scheduler, timesteps);
 
                 // Get Model metadata
                 var metadata = await _unet.GetMetadataAsync();
@@ -89,18 +150,15 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
                     var inputLatent = performGuidance ? latents.Repeat(2) : latents;
                     var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                     var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
-                    var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, 768 });
-
-                    var outputChannels = performGuidance ? 2 : 1;
-                    var outputDimension = inputTensor.Dimensions.ToArray();
+                    var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels });
                     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                     {
                         inferenceParameters.AddInputTensor(inputTensor);
                         inferenceParameters.AddInputTensor(timestepTensor);
                         inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                         inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                         inferenceParameters.AddInputTensor(imageEmbeds);
-                        inferenceParameters.AddOutputBuffer(outputDimension);
+                        inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);
 
                         var results = await _unet.RunInferenceAsync(inferenceParameters);
                         using (var result = results.First())
@@ -129,23 +187,33 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
         }
 
 
-        protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> latentsPrior, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        /// <summary>
+        /// Run the Decoder UNET diffusion
+        /// </summary>
+        /// <param name="prompt">The prompt.</param>
+        /// <param name="priorLatents">The prior latents.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        protected async Task<DenseTensor<float>> DiffuseDecodeAsync(PromptOptions prompt, DenseTensor<float> priorLatents, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
         {
             using (var scheduler = GetScheduler(schedulerOptions))
             {
                 // Get timesteps
                 var timesteps = GetTimesteps(schedulerOptions, scheduler);
 
                 // Create latent sample
-                var latents = scheduler.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, scheduler.InitNoiseSigma);
+                var latents = await PrepareDecoderLatentsAsync(prompt, schedulerOptions, scheduler, timesteps, priorLatents);
 
                 // Get Model metadata
                 var metadata = await _decoderUnet.GetMetadataAsync();
 
-                var effnet = performGuidance
-                    ? latentsPrior
-                    : latentsPrior.Concatenate(new DenseTensor<float>(latentsPrior.Dimensions));
-
+                var effnet = !performGuidance
+                    ? priorLatents
+                    : priorLatents.Repeat(2);
 
                 // Loop though the timesteps
                 var step = 0;
@@ -159,18 +227,15 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
                     var inputLatent = performGuidance ? latents.Repeat(2) : latents;
                     var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                     var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
-
-                    var outputChannels = performGuidance ? 2 : 1;
-                    var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
                     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                     {
                         inferenceParameters.AddInputTensor(inputTensor);
                         inferenceParameters.AddInputTensor(timestepTensor);
                         inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                         inferenceParameters.AddInputTensor(effnet);
-                        inferenceParameters.AddOutputBuffer();
+                        inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);
 
-                        var results = _decoderUnet.RunInference(inferenceParameters);
+                        var results = await _decoderUnet.RunInferenceAsync(inferenceParameters);
                         using (var result = results.First())
                         {
                             var noisePred = result.ToDenseTensor();
@@ -197,6 +262,13 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
         }
 
 
+        /// <summary>
+        /// Decodes the latents.
+        /// </summary>
+        /// <param name="prompt">The prompt.</param>
+        /// <param name="options">The options.</param>
+        /// <param name="latents">The latents.</param>
+        /// <returns></returns>
         protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
         {
             latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor);
@@ -239,6 +311,19 @@ private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int
         }
 
 
+        /// <summary>
+        /// Splits the prompt embeddings, Removes unconditional embeddings
+        /// </summary>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <returns></returns>
+        private PromptEmbeddingsResult SplitPromptEmbeddings(PromptEmbeddingsResult promptEmbeddings)
+        {
+            return promptEmbeddings.PooledPromptEmbeds is null
+                ? new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds.SplitBatch().Last())
+                : new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds.SplitBatch().Last(), promptEmbeddings.PooledPromptEmbeds.SplitBatch().Last());
+        }
+
+
         /// <summary>
         /// Gets the scheduler.
         /// </summary>
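
For context on the new SplitPromptEmbeddings helper: when guidance is requested the embeddings arrive batched as [negative, positive], and a stage whose guidance scale is 0 only needs the positive half. A shape-level sketch; the 77 and 1280 dimensions are illustrative, not taken from this commit:

    // Illustrative shapes only; SplitBatch().Last() keeps the final (positive) batch item.
    int[] guidedPromptEmbeds   = { 2, 77, 1280 };   // [negative, positive]
    int[] guidedPooledEmbeds   = { 2, 1, 1280 };
    int[] unguidedPromptEmbeds = { 1, 77, 1280 };   // what an unguided stage receives
    int[] unguidedPooledEmbeds = { 1, 1, 1280 };
    // The unguided stage's UNet then runs with batch size 1 instead of a redundant
    // unconditional pass, which is why DiffuseAsync splits per stage before dispatching.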

OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs

Lines changed: 20 additions & 8 deletions
@@ -5,6 +5,7 @@
 using OnnxStack.StableDiffusion.Config;
 using OnnxStack.StableDiffusion.Enums;
 using OnnxStack.StableDiffusion.Models;
+using System;
 using System.Collections.Generic;
 using System.Threading.Tasks;
 
@@ -47,16 +48,27 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
         }
 
 
-        /// <summary>
-        /// Prepares the latents for inference.
-        /// </summary>
-        /// <param name="prompt">The prompt.</param>
-        /// <param name="options">The options.</param>
-        /// <param name="scheduler">The scheduler.</param>
-        /// <returns></returns>
         protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
         {
-            return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
+            var latents = scheduler.CreateRandomSample(new[]
+            {
+                1, 16,
+                (int)Math.Ceiling(options.Height / ResolutionMultiple),
+                (int)Math.Ceiling(options.Width / ResolutionMultiple)
+            }, scheduler.InitNoiseSigma);
+            return Task.FromResult(latents);
+        }
+
+
+        protected override Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
+        {
+            var latents = scheduler.CreateRandomSample(new[]
+            {
+                1, 4,
+                (int)(priorLatents.Dimensions[2] * LatentDimScale),
+                (int)(priorLatents.Dimensions[3] * LatentDimScale)
+            }, scheduler.InitNoiseSigma);
+            return Task.FromResult(latents);
         }
     }
 }
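
To make the two scale constants concrete, here is a standalone arithmetic check of the shapes these methods produce; the 1024x1024 request size is an example, only the constants come from the commit:

    const float resolutionMultiple = 42.67f;   // image pixels -> prior latent cells
    const float latentDimScale = 10.67f;       // prior latent cells -> decoder (VQ) latent cells

    int height = 1024, width = 1024;                              // example request
    int priorH = (int)Math.Ceiling(height / resolutionMultiple);  // ceil(23.998) = 24
    int priorW = (int)Math.Ceiling(width / resolutionMultiple);   // 24
    int decoderH = (int)(priorH * latentDimScale);                // (int)256.08 = 256
    int decoderW = (int)(priorW * latentDimScale);                // 256

    Console.WriteLine($"prior latents:   1 x 16 x {priorH} x {priorW}");    // 1 x 16 x 24 x 24
    Console.WriteLine($"decoder latents: 1 x 4 x {decoderH} x {decoderW}"); // 1 x 4 x 256 x 256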

OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs

Lines changed: 11 additions & 0 deletions
@@ -104,6 +104,17 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
         }
 
 
+        /// <summary>
+        /// Check if we should run guidance.
+        /// </summary>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <returns></returns>
+        protected override bool ShouldPerformGuidance(SchedulerOptions schedulerOptions)
+        {
+            return schedulerOptions.GuidanceScale > 1f || schedulerOptions.GuidanceScale2 > 1f;
+        }
+
+
         /// <summary>
         /// Creates the embeds using Tokenizer2 and TextEncoder2
         /// </summary>
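
The override means the negative-prompt batch is prepared whenever either stage wants guidance. A minimal sketch of the rule, restated for illustration rather than as the library's API surface:

    // Re-statement of the override's rule for illustration.
    static bool ShouldPerformGuidance(float guidanceScale, float guidanceScale2)
        => guidanceScale > 1f || guidanceScale2 > 1f;

    Console.WriteLine(ShouldPerformGuidance(4f, 0f));   // True  - prior guided, decoder unguided
    Console.WriteLine(ShouldPerformGuidance(1f, 1f));   // False - no negative-prompt batch is built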

OnnxStack.StableDiffusion/Schedulers/StableDiffusion/DDPMWuerstchenScheduler.cs

Lines changed: 9 additions & 10 deletions
@@ -13,6 +13,7 @@ internal class DDPMWuerstchenScheduler : SchedulerBase
         private float _s;
         private float _scaler;
         private float _initAlphaCumprod;
+        private float _timestepRatio = 1000f;
 
 
         /// <summary>
@@ -47,14 +48,12 @@ protected override void Initialize()
         /// <returns></returns>
         protected override int[] SetTimesteps()
         {
-            // Create timesteps based on the specified strategy
-            var timesteps = ArrayHelpers.Linspace(0, 1000, Options.InferenceSteps + 1);
-            var x = timesteps
+            var timesteps = ArrayHelpers.Linspace(0, _timestepRatio, Options.InferenceSteps + 1);
+            return timesteps
                 .Skip(1)
                 .Select(x => (int)x)
                 .OrderByDescending(x => x)
                 .ToArray();
-            return x;
         }
 
 
@@ -82,8 +81,8 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
         /// <exception cref="NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
         public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
         {
-            var currentTimestep = timestep / 1000f;
-            var previousTimestep = GetPreviousTimestep(timestep) / 1000f;
+            var currentTimestep = timestep / _timestepRatio;
+            var previousTimestep = GetPreviousTimestep(timestep) / _timestepRatio;
 
             var alpha_cumprod = GetAlphaCumprod(currentTimestep);
             var alpha_cumprod_prev = GetAlphaCumprod(previousTimestep);
@@ -108,10 +107,10 @@ public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int tim
         /// <returns></returns>
         public override DenseTensor<float> AddNoise(DenseTensor<float> originalSamples, DenseTensor<float> noise, IReadOnlyList<int> timesteps)
         {
-            float timestep = timesteps[0] / 1000f;
-            float alphaProd = GetAlphaCumprod(timestep);
-            float sqrtAlpha = MathF.Sqrt(alphaProd);
-            float sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);
+            var timestep = timesteps[0] / _timestepRatio;
+            var alphaProd = GetAlphaCumprod(timestep);
+            var sqrtAlpha = MathF.Sqrt(alphaProd);
+            var sqrtOneMinusAlpha = MathF.Sqrt(1.0f - alphaProd);
 
             return noise
                 .MultiplyTensorByFloat(sqrtOneMinusAlpha)
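
For reference, what the refactored SetTimesteps yields for a small step count; the Linspace below is a plain re-implementation for illustration (the scheduler itself uses OnnxStack's ArrayHelpers.Linspace):

    using System.Linq;

    // Plain re-implementation of Linspace for illustration only.
    static float[] Linspace(float start, float end, int count) =>
        Enumerable.Range(0, count)
                  .Select(i => start + (end - start) * i / (count - 1))
                  .ToArray();

    const float timestepRatio = 1000f;
    const int inferenceSteps = 4;

    var timesteps = Linspace(0, timestepRatio, inferenceSteps + 1)  // [0, 250, 500, 750, 1000]
        .Skip(1)                                                    // drop the leading 0
        .Select(x => (int)x)
        .OrderByDescending(x => x)
        .ToArray();                                                 // [1000, 750, 500, 250]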

0 commit comments
