8
8
using OnnxStack . StableDiffusion . Models ;
9
9
using OnnxStack . StableDiffusion . Schedulers . StableDiffusion ;
10
10
using System ;
11
+ using System . Collections . Generic ;
11
12
using System . Diagnostics ;
12
13
using System . Linq ;
13
14
using System . Threading ;
@@ -17,6 +18,9 @@ namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
17
18
{
18
19
public abstract class StableCascadeDiffuser : DiffuserBase
19
20
{
21
+ private readonly float _latentDimScale ;
22
+ private readonly float _resolutionMultiple ;
23
+ private readonly int _clipImageChannels ;
20
24
private readonly UNetConditionModel _decoderUnet ;
21
25
22
26
/// <summary>
@@ -32,6 +36,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
32
36
: base ( priorUnet , decoderVqgan , imageEncoder , memoryMode , logger )
33
37
{
34
38
_decoderUnet = decoderUnet ;
39
+ _latentDimScale = 10.67f ;
40
+ _resolutionMultiple = 42.67f ;
41
+ _clipImageChannels = 768 ;
35
42
}
36
43
37
44
/// <summary>
@@ -40,6 +47,32 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
40
47
public override DiffuserPipelineType PipelineType => DiffuserPipelineType . StableCascade ;
41
48
42
49
50
/// <summary>
/// Multiplier that maps the prior image-embedding spatial size to the VQ latent spatial size.
/// If the image embeddings are height=24 and width=24, the VQ latent shape needs to be
/// height=int(24 * 10.67)=256 and width=int(24 * 10.67)=256 in order to match the training conditions.
/// </summary>
protected float LatentDimScale => _latentDimScale;
56
+
57
+
58
/// <summary>
/// Default multiple used to derive the prior latent resolution from the requested
/// image width/height (height and width are divided by this value when creating
/// the prior latent sample).
/// </summary>
protected float ResolutionMultiple => _resolutionMultiple;
62
+
63
+
64
/// <summary>
/// Prepares the initial latent sample for the decoder UNet stage.
/// </summary>
/// <param name="prompt">The prompt options.</param>
/// <param name="options">The scheduler options for the decoder stage.</param>
/// <param name="scheduler">The scheduler in use for the decoder stage.</param>
/// <param name="timesteps">The decoder timesteps.</param>
/// <param name="priorLatents">The latents produced by the prior UNet.</param>
/// <returns>The initial decoder latent tensor.</returns>
protected abstract Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents);
74
+
75
+
43
76
/// <summary>
/// Runs the full Stable Cascade diffusion: the prior UNet produces image embeddings,
/// the decoder UNet refines them, and the latents are then decoded to the final image tensor.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options (prior stage; decoder stage values come from InferenceSteps2/GuidanceScale2).</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> the embeddings were created for classifier-free guidance.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
    // The decoder stage runs with its own step count and guidance scale.
    var decodeSchedulerOptions = schedulerOptions with
    {
        InferenceSteps = schedulerOptions.InferenceSteps2,
        GuidanceScale = schedulerOptions.GuidanceScale2
    };

    var priorPerformGuidance = schedulerOptions.GuidanceScale > 0;
    var decoderPerformGuidance = decodeSchedulerOptions.GuidanceScale > 0;

    // If the embeddings were built for guidance but a stage runs unguided,
    // drop the unconditional half so the batch sizes match that stage.
    var priorPromptEmbeddings = promptEmbeddings;
    var decoderPromptEmbeddings = promptEmbeddings;
    if (performGuidance)
    {
        if (!priorPerformGuidance)
            priorPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
        if (!decoderPerformGuidance)
            decoderPromptEmbeddings = SplitPromptEmbeddings(promptEmbeddings);
    }

    // Stage 1: Prior Unet
    var priorLatents = await DiffusePriorAsync(promptOptions, schedulerOptions, priorPromptEmbeddings, priorPerformGuidance, progressCallback, cancellationToken);

    // Stage 2: Decoder Unet
    var decoderLatents = await DiffuseDecodeAsync(promptOptions, priorLatents, decodeSchedulerOptions, decoderPromptEmbeddings, decoderPerformGuidance, progressCallback, cancellationToken);

    // Stage 3: Decode latents to the final image tensor
    return await DecodeLatentsAsync(promptOptions, schedulerOptions, decoderLatents);
}
65
115
66
116
67
- protected async Task < DenseTensor < float > > DiffusePriorAsync ( SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
117
+
118
+ /// <summary>
119
+ /// Run the Prior UNET diffusion
120
+ /// </summary>
121
+ /// <param name="prompt">The prompt.</param>
122
+ /// <param name="schedulerOptions">The scheduler options.</param>
123
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
124
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
125
+ /// <param name="progressCallback">The progress callback.</param>
126
+ /// <param name="cancellationToken">The cancellation token.</param>
127
+ /// <returns></returns>
128
+ protected async Task < DenseTensor < float > > DiffusePriorAsync ( PromptOptions prompt , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
68
129
{
69
130
using ( var scheduler = GetScheduler ( schedulerOptions ) )
70
131
{
71
132
// Get timesteps
72
133
var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
73
134
74
135
// Create latent sample
75
- var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 16 , ( int ) Math . Ceiling ( schedulerOptions . Height / 42.67f ) , ( int ) Math . Ceiling ( schedulerOptions . Width / 42.67f ) } , scheduler . InitNoiseSigma ) ;
136
+ var latents = await PrepareLatentsAsync ( prompt , schedulerOptions , scheduler , timesteps ) ;
76
137
77
138
// Get Model metadata
78
139
var metadata = await _unet . GetMetadataAsync ( ) ;
@@ -89,18 +150,15 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
89
150
var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
90
151
var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
91
152
var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
92
- var imageEmbeds = new DenseTensor < float > ( new [ ] { performGuidance ? 2 : 1 , 1 , 768 } ) ;
93
-
94
- var outputChannels = performGuidance ? 2 : 1 ;
95
- var outputDimension = inputTensor . Dimensions . ToArray ( ) ;
153
+ var imageEmbeds = new DenseTensor < float > ( new [ ] { performGuidance ? 2 : 1 , 1 , _clipImageChannels } ) ;
96
154
using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
97
155
{
98
156
inferenceParameters . AddInputTensor ( inputTensor ) ;
99
157
inferenceParameters . AddInputTensor ( timestepTensor ) ;
100
158
inferenceParameters . AddInputTensor ( promptEmbeddings . PooledPromptEmbeds ) ;
101
159
inferenceParameters . AddInputTensor ( promptEmbeddings . PromptEmbeds ) ;
102
160
inferenceParameters . AddInputTensor ( imageEmbeds ) ;
103
- inferenceParameters . AddOutputBuffer ( outputDimension ) ;
161
+ inferenceParameters . AddOutputBuffer ( inputTensor . Dimensions ) ;
104
162
105
163
var results = await _unet . RunInferenceAsync ( inferenceParameters ) ;
106
164
using ( var result = results . First ( ) )
@@ -129,23 +187,33 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
129
187
}
130
188
131
189
132
- protected async Task < DenseTensor < float > > DiffuseDecodeAsync ( DenseTensor < float > latentsPrior , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
190
+ /// <summary>
191
+ /// Run the Decoder UNET diffusion
192
+ /// </summary>
193
+ /// <param name="prompt">The prompt.</param>
194
+ /// <param name="priorLatents">The prior latents.</param>
195
+ /// <param name="schedulerOptions">The scheduler options.</param>
196
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
197
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
198
+ /// <param name="progressCallback">The progress callback.</param>
199
+ /// <param name="cancellationToken">The cancellation token.</param>
200
+ /// <returns></returns>
201
+ protected async Task < DenseTensor < float > > DiffuseDecodeAsync ( PromptOptions prompt , DenseTensor < float > priorLatents , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
133
202
{
134
203
using ( var scheduler = GetScheduler ( schedulerOptions ) )
135
204
{
136
205
// Get timesteps
137
206
var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
138
207
139
208
// Create latent sample
140
- var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 4 , ( int ) ( latentsPrior . Dimensions [ 2 ] * 10.67f ) , ( int ) ( latentsPrior . Dimensions [ 3 ] * 10.67f ) } , scheduler . InitNoiseSigma ) ;
209
+ var latents = await PrepareDecoderLatentsAsync ( prompt , schedulerOptions , scheduler , timesteps , priorLatents ) ;
141
210
142
211
// Get Model metadata
143
212
var metadata = await _decoderUnet . GetMetadataAsync ( ) ;
144
213
145
- var effnet = performGuidance
146
- ? latentsPrior
147
- : latentsPrior . Concatenate ( new DenseTensor < float > ( latentsPrior . Dimensions ) ) ;
148
-
214
+ var effnet = ! performGuidance
215
+ ? priorLatents
216
+ : priorLatents . Repeat ( 2 ) ;
149
217
150
218
// Loop though the timesteps
151
219
var step = 0 ;
@@ -159,18 +227,15 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
159
227
var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
160
228
var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
161
229
var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
162
-
163
- var outputChannels = performGuidance ? 2 : 1 ;
164
- var outputDimension = inputTensor . Dimensions . ToArray ( ) ; //schedulerOptions.GetScaledDimension(outputChannels);
165
230
using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
166
231
{
167
232
inferenceParameters . AddInputTensor ( inputTensor ) ;
168
233
inferenceParameters . AddInputTensor ( timestepTensor ) ;
169
234
inferenceParameters . AddInputTensor ( promptEmbeddings . PooledPromptEmbeds ) ;
170
235
inferenceParameters . AddInputTensor ( effnet ) ;
171
- inferenceParameters . AddOutputBuffer ( ) ;
236
+ inferenceParameters . AddOutputBuffer ( inputTensor . Dimensions ) ;
172
237
173
- var results = _decoderUnet . RunInference ( inferenceParameters ) ;
238
+ var results = await _decoderUnet . RunInferenceAsync ( inferenceParameters ) ;
174
239
using ( var result = results . First ( ) )
175
240
{
176
241
var noisePred = result . ToDenseTensor ( ) ;
@@ -197,6 +262,13 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
197
262
}
198
263
199
264
265
+ /// <summary>
266
+ /// Decodes the latents.
267
+ /// </summary>
268
+ /// <param name="prompt">The prompt.</param>
269
+ /// <param name="options">The options.</param>
270
+ /// <param name="latents">The latents.</param>
271
+ /// <returns></returns>
200
272
protected override async Task < DenseTensor < float > > DecodeLatentsAsync ( PromptOptions prompt , SchedulerOptions options , DenseTensor < float > latents )
201
273
{
202
274
latents = latents . MultiplyBy ( _vaeDecoder . ScaleFactor ) ;
@@ -239,6 +311,19 @@ private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int
239
311
}
240
312
241
313
314
/// <summary>
/// Splits the prompt embeddings, removing the unconditional embeddings and
/// keeping only the conditional (last) batch.
/// </summary>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <returns>A result containing only the conditional embeddings.</returns>
private PromptEmbeddingsResult SplitPromptEmbeddings(PromptEmbeddingsResult promptEmbeddings)
{
    var conditionalPrompt = promptEmbeddings.PromptEmbeds.SplitBatch().Last();
    if (promptEmbeddings.PooledPromptEmbeds is null)
        return new PromptEmbeddingsResult(conditionalPrompt);

    var conditionalPooled = promptEmbeddings.PooledPromptEmbeds.SplitBatch().Last();
    return new PromptEmbeddingsResult(conditionalPrompt, conditionalPooled);
}
325
+
326
+
242
327
/// <summary>
243
328
/// Gets the scheduler.
244
329
/// </summary>
0 commit comments