Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 9b3acf4

Browse files
committed
Split Prior and Decoder logic
1 parent 8d5575a commit 9b3acf4

File tree

1 file changed

+53
-48
lines changed

1 file changed

+53
-48
lines changed

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime.Tensors;
33
using OnnxStack.Core;
4-
using OnnxStack.Core.Image;
54
using OnnxStack.Core.Model;
65
using OnnxStack.StableDiffusion.Common;
76
using OnnxStack.StableDiffusion.Config;
@@ -53,40 +52,48 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
5352
/// <returns></returns>
5453
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
5554
{
56-
// Get Scheduler
57-
using (var schedulerPrior = GetScheduler(schedulerOptions))
58-
using (var schedulerDecoder = GetScheduler(schedulerOptions with{ InferenceSteps = 10, GuidanceScale = 0}))
59-
{
60-
//----------------------------------------------------
61-
// Prior Unet
62-
//====================================================
55+
// Prior Unet
56+
var latentsPrior = await DiffusePriorAsync(schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
57+
58+
// Decoder Unet
59+
var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 };
60+
var latents = await DiffuseDecodeAsync(latentsPrior, schedulerOptionsDecoder, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
61+
62+
// Decode Latents
63+
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
64+
}
65+
6366

67+
protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
68+
{
69+
using (var scheduler = GetScheduler(schedulerOptions))
70+
{
6471
// Get timesteps
65-
var timestepsPrior = GetTimesteps(schedulerOptions, schedulerPrior);
72+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
6673

6774
// Create latent sample
68-
var latentsPrior = schedulerPrior.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, schedulerPrior.InitNoiseSigma);
75+
var latents = scheduler.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, scheduler.InitNoiseSigma);
6976

7077
// Get Model metadata
71-
var metadataPrior = await _unet.GetMetadataAsync();
78+
var metadata = await _unet.GetMetadataAsync();
7279

7380
// Loop through the timesteps
74-
var stepPrior = 0;
75-
foreach (var timestep in timestepsPrior)
81+
var step = 0;
82+
foreach (var timestep in timesteps)
7683
{
77-
stepPrior++;
84+
step++;
7885
var stepTime = Stopwatch.GetTimestamp();
7986
cancellationToken.ThrowIfCancellationRequested();
8087

8188
// Create input tensor.
82-
var inputLatent = performGuidance ? latentsPrior.Repeat(2) : latentsPrior;
83-
var inputTensor = schedulerPrior.ScaleInput(inputLatent, timestep);
84-
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
85-
var imageEmbeds = new DenseTensor<float>(performGuidance ? new[] { 2, 1, 768 } : new[] { 1, 1, 768 });
89+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
90+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
91+
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
92+
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, 768 });
8693

8794
var outputChannels = performGuidance ? 2 : 1;
88-
var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
89-
using (var inferenceParameters = new OnnxInferenceParameters(metadataPrior))
95+
var outputDimension = inputTensor.Dimensions.ToArray();
96+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
9097
{
9198
inferenceParameters.AddInputTensor(inputTensor);
9299
inferenceParameters.AddInputTensor(timestepTensor);
@@ -105,58 +112,57 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
105112
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
106113

107114
// Scheduler Step
108-
latentsPrior = schedulerPrior.Step(noisePred, timestep, latentsPrior).Result;
115+
latents = scheduler.Step(noisePred, timestep, latents).Result;
109116
}
110117
}
111118

112-
ReportProgress(progressCallback, stepPrior, timestepsPrior.Count, latentsPrior);
113-
_logger?.LogEnd(LogLevel.Debug, $"Step {stepPrior}/{timestepsPrior.Count}", stepTime);
119+
ReportProgress(progressCallback, step, timesteps.Count, latents);
120+
_logger?.LogEnd(LogLevel.Debug, $"Prior Step {step}/{timesteps.Count}", stepTime);
114121
}
115122

116123
// Unload if required
117124
if (_memoryMode == MemoryModeType.Minimum)
118125
await _unet.UnloadAsync();
119126

127+
return latents;
128+
}
129+
}
120130

121131

122-
123-
124-
//----------------------------------------------------
125-
// Decoder Unet
126-
//====================================================
127-
132+
protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> latentsPrior, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
133+
{
134+
using (var scheduler = GetScheduler(schedulerOptions))
135+
{
128136
// Get timesteps
129-
var timestepsDecoder = GetTimesteps(schedulerOptions, schedulerDecoder);
137+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
130138

131139
// Create latent sample
132-
133-
var latentsDecoder = schedulerDecoder.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, schedulerDecoder.InitNoiseSigma);
140+
var latents = scheduler.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, scheduler.InitNoiseSigma);
134141

135142
// Get Model metadata
136-
var metadataDecoder = await _decoderUnet.GetMetadataAsync();
143+
var metadata = await _decoderUnet.GetMetadataAsync();
137144

138145
var effnet = performGuidance
139146
? latentsPrior
140147
: latentsPrior.Concatenate(new DenseTensor<float>(latentsPrior.Dimensions));
141148

142149

143150
// Loop through the timesteps
144-
var stepDecoder = 0;
145-
foreach (var timestep in timestepsDecoder)
151+
var step = 0;
152+
foreach (var timestep in timesteps)
146153
{
147-
stepDecoder++;
154+
step++;
148155
var stepTime = Stopwatch.GetTimestamp();
149156
cancellationToken.ThrowIfCancellationRequested();
150157

151158
// Create input tensor.
152-
var inputLatent = performGuidance ? latentsDecoder.Repeat(2) : latentsDecoder;
153-
var inputTensor = schedulerDecoder.ScaleInput(inputLatent, timestep);
159+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
160+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
154161
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
155162

156-
157163
var outputChannels = performGuidance ? 2 : 1;
158164
var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
159-
using (var inferenceParameters = new OnnxInferenceParameters(metadataDecoder))
165+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
160166
{
161167
inferenceParameters.AddInputTensor(inputTensor);
162168
inferenceParameters.AddInputTensor(timestepTensor);
@@ -174,20 +180,19 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
174180
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
175181

176182
// Scheduler Step
177-
latentsDecoder = schedulerDecoder.Step(noisePred, timestep, latentsDecoder).Result;
183+
latents = scheduler.Step(noisePred, timestep, latents).Result;
178184
}
179185
}
180186

187+
ReportProgress(progressCallback, step, timesteps.Count, latents);
188+
_logger?.LogEnd(LogLevel.Debug, $"Decoder Step {step}/{timesteps.Count}", stepTime);
181189
}
182190

183-
var testlatentsPrior = new OnnxImage(latentsPrior);
184-
var testlatentsDecoder = new OnnxImage(latentsDecoder);
185-
await testlatentsPrior.SaveAsync("D:\\testlatentsPrior.png");
186-
await testlatentsDecoder.SaveAsync("D:\\latentsDecoder.png");
187-
191+
// Unload if required
192+
if (_memoryMode == MemoryModeType.Minimum)
193+
await _unet.UnloadAsync();
188194

189-
// Decode Latents
190-
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latentsDecoder);
195+
return latents;
191196
}
192197
}
193198

0 commit comments

Comments
 (0)