 using Microsoft.Extensions.Logging;
 using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core;
-using OnnxStack.Core.Image;
 using OnnxStack.Core.Model;
 using OnnxStack.StableDiffusion.Common;
 using OnnxStack.StableDiffusion.Config;
@@ -53,40 +52,48 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
         /// <returns></returns>
         public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
         {
-            // Get Scheduler
-            using (var schedulerPrior = GetScheduler(schedulerOptions))
-            using (var schedulerDecoder = GetScheduler(schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 }))
-            {
-                //----------------------------------------------------
-                // Prior Unet
-                //====================================================
+            // Prior Unet
+            var latentsPrior = await DiffusePriorAsync(schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
+
+            // Decoder Unet
+            var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 };
+            var latents = await DiffuseDecodeAsync(latentsPrior, schedulerOptionsDecoder, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
+
+            // Decode Latents
+            return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
+        }
+

+        protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
                 // Get timesteps
-                var timestepsPrior = GetTimesteps(schedulerOptions, schedulerPrior);
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);

                 // Create latent sample
-                var latentsPrior = schedulerPrior.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, schedulerPrior.InitNoiseSigma);
+                var latents = scheduler.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, scheduler.InitNoiseSigma);

                 // Get Model metadata
-                var metadataPrior = await _unet.GetMetadataAsync();
+                var metadata = await _unet.GetMetadataAsync();

                 // Loop through the timesteps
-                var stepPrior = 0;
-                foreach (var timestep in timestepsPrior)
+                var step = 0;
+                foreach (var timestep in timesteps)
                 {
-                    stepPrior++;
+                    step++;
                     var stepTime = Stopwatch.GetTimestamp();
                     cancellationToken.ThrowIfCancellationRequested();

                     // Create input tensor.
-                    var inputLatent = performGuidance ? latentsPrior.Repeat(2) : latentsPrior;
-                    var inputTensor = schedulerPrior.ScaleInput(inputLatent, timestep);
-                    var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
-                    var imageEmbeds = new DenseTensor<float>(performGuidance ? new[] { 2, 1, 768 } : new[] { 1, 1, 768 });
+                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+                    var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
+                    var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, 768 });

                     var outputChannels = performGuidance ? 2 : 1;
-                    var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
-                    using (var inferenceParameters = new OnnxInferenceParameters(metadataPrior))
+                    var outputDimension = inputTensor.Dimensions.ToArray();
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                     {
                         inferenceParameters.AddInputTensor(inputTensor);
                         inferenceParameters.AddInputTensor(timestepTensor);
@@ -105,58 +112,57 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
                                 noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

                             // Scheduler Step
-                            latentsPrior = schedulerPrior.Step(noisePred, timestep, latentsPrior).Result;
+                            latents = scheduler.Step(noisePred, timestep, latents).Result;
                         }
                     }

-                    ReportProgress(progressCallback, stepPrior, timestepsPrior.Count, latentsPrior);
-                    _logger?.LogEnd(LogLevel.Debug, $"Step {stepPrior}/{timestepsPrior.Count}", stepTime);
+                    ReportProgress(progressCallback, step, timesteps.Count, latents);
+                    _logger?.LogEnd(LogLevel.Debug, $"Prior Step {step}/{timesteps.Count}", stepTime);
                 }

                 // Unload if required
                 if (_memoryMode == MemoryModeType.Minimum)
                     await _unet.UnloadAsync();

+                return latents;
+            }
+        }


-
-
-                //----------------------------------------------------
-                // Decoder Unet
-                //====================================================
-
+        protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> latentsPrior, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
                 // Get timesteps
-                var timestepsDecoder = GetTimesteps(schedulerOptions, schedulerDecoder);
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);

                 // Create latent sample
-
-                var latentsDecoder = schedulerDecoder.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, schedulerDecoder.InitNoiseSigma);
+                var latents = scheduler.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, scheduler.InitNoiseSigma);

                 // Get Model metadata
-                var metadataDecoder = await _decoderUnet.GetMetadataAsync();
+                var metadata = await _decoderUnet.GetMetadataAsync();

                 var effnet = performGuidance
                     ? latentsPrior
                     : latentsPrior.Concatenate(new DenseTensor<float>(latentsPrior.Dimensions));


                 // Loop through the timesteps
-                var stepDecoder = 0;
-                foreach (var timestep in timestepsDecoder)
+                var step = 0;
+                foreach (var timestep in timesteps)
                 {
-                    stepDecoder++;
+                    step++;
                     var stepTime = Stopwatch.GetTimestamp();
                     cancellationToken.ThrowIfCancellationRequested();

                     // Create input tensor.
-                    var inputLatent = performGuidance ? latentsDecoder.Repeat(2) : latentsDecoder;
-                    var inputTensor = schedulerDecoder.ScaleInput(inputLatent, timestep);
+                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                     var timestepTensor = CreateTimestepTensor(inputLatent, timestep);

-
                     var outputChannels = performGuidance ? 2 : 1;
                     var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
-                    using (var inferenceParameters = new OnnxInferenceParameters(metadataDecoder))
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                     {
                         inferenceParameters.AddInputTensor(inputTensor);
                         inferenceParameters.AddInputTensor(timestepTensor);
@@ -174,20 +180,19 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
                                 noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

                             // Scheduler Step
-                            latentsDecoder = schedulerDecoder.Step(noisePred, timestep, latentsDecoder).Result;
+                            latents = scheduler.Step(noisePred, timestep, latents).Result;
                         }
                     }

+                    ReportProgress(progressCallback, step, timesteps.Count, latents);
+                    _logger?.LogEnd(LogLevel.Debug, $"Decoder Step {step}/{timesteps.Count}", stepTime);
                 }

-                var testlatentsPrior = new OnnxImage(latentsPrior);
-                var testlatentsDecoder = new OnnxImage(latentsDecoder);
-                await testlatentsPrior.SaveAsync("D:\\testlatentsPrior.png");
-                await testlatentsDecoder.SaveAsync("D:\\latentsDecoder.png");
-
+                // Unload if required
+                if (_memoryMode == MemoryModeType.Minimum)
+                    await _unet.UnloadAsync();

-                // Decode Latents
-                return await DecodeLatentsAsync(promptOptions, schedulerOptions, latentsDecoder);
+                return latents;
             }
         }

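For reference, a minimal sketch of the latent shapes implied by the hard-coded scale factors in the diff: the 42.67f prior compression and 10.67f decoder upscale are taken directly from the code above, while the 1024x1024 example resolution is an assumption for illustration and not part of the commit.

using System;

// Minimal sketch (not part of the commit): latent dimensions implied by the
// scale factors above, assuming a hypothetical 1024x1024 request.
int height = 1024, width = 1024;

// Prior latents: 1 x 16 x ceil(H / 42.67) x ceil(W / 42.67)  =>  1 x 16 x 24 x 24
int priorHeight = (int)Math.Ceiling(height / 42.67f);   // 24
int priorWidth  = (int)Math.Ceiling(width / 42.67f);    // 24

// Decoder latents: 1 x 4 x (24 * 10.67) x (24 * 10.67)  =>  1 x 4 x 256 x 256,
// i.e. one quarter of the requested image resolution in each spatial dimension.
int decoderHeight = (int)(priorHeight * 10.67f);        // 256
int decoderWidth  = (int)(priorWidth * 10.67f);         // 256

Console.WriteLine($"Prior latents:   1 x 16 x {priorHeight} x {priorWidth}");
Console.WriteLine($"Decoder latents: 1 x 4 x {decoderHeight} x {decoderWidth}");

This is consistent with the two-stage Stable Cascade layout the diffuser targets: the prior denoises a heavily compressed latent, and the decoder denoises a latent at roughly a quarter of the final image resolution before DecodeLatentsAsync produces the image.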