Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit e3ce080

Browse files
committed
ImageToImage
1 parent 2ab0fce commit e3ce080

File tree

5 files changed

+175
-35
lines changed

5 files changed

+175
-35
lines changed

OnnxStack.Core/Image/OnnxImage.cs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,30 @@ public DenseTensor<float> GetImageTensor(ImageNormalizeType normalizeType = Imag
230230
}
231231

232232

233+
/// <summary>
/// Gets the image as a normalized feature tensor.
/// </summary>
/// <remarks>
/// Operates on the result of <see cref="Clone"/>, resized to 224 x 224, and fills a tensor of
/// dimensions [1, 3, 224, 224]. Channel planes 0, 1 and 2 hold the R, G and B values, each
/// scaled to 0-1 and then normalized as (value - mean) / stddev.
/// </remarks>
/// <returns>The normalized image tensor.</returns>
public DenseTensor<float> GetClipImageFeatureTensor()
{
    const int featureSize = 224;

    var image = Clone();
    image.Resize(featureSize, featureSize);
    var imageArray = new DenseTensor<float>(new[] { 1, 3, featureSize, featureSize });

    // NOTE(review): these values look like the widely used ImageNet mean/stddev rather than
    // CLIP-specific preprocessing values - confirm against the image encoder's preprocessor config.
    var mean = new[] { 0.485f, 0.456f, 0.406f };
    var stddev = new[] { 0.229f, 0.224f, 0.225f };

    image.ProcessPixelRows(img =>
    {
        // Walk the image row by row so each row span is fetched once per row
        // instead of once per pixel.
        for (int y = 0; y < img.Height; y++)
        {
            var pixelSpan = img.GetRowSpan(y);
            for (int x = 0; x < pixelSpan.Length; x++)
            {
                var pixel = pixelSpan[x];
                imageArray[0, 0, y, x] = ((pixel.R / 255f) - mean[0]) / stddev[0];
                imageArray[0, 1, y, x] = ((pixel.G / 255f) - mean[1]) / stddev[1];
                imageArray[0, 2, y, x] = ((pixel.B / 255f) - mean[2]) / stddev[2];
            }
        }
    });
    return imageArray;
}
255+
256+
233257
/// <summary>
234258
/// Gets the image as tensor.
235259
/// </summary>
@@ -293,7 +317,12 @@ public void Resize(int height, int width, ImageResizeMode resizeMode = ImageResi
293317
});
294318
}
295319

296-
320+
/// <summary>
/// Forwards the pixel row callback to the underlying image data.
/// </summary>
/// <param name="processPixels">The callback that receives the pixel accessor.</param>
public void ProcessPixelRows(PixelAccessorAction<Rgba32> processPixels)
    => _imageData.ProcessPixelRows(processPixels);
324+
325+
297326
public OnnxImage Clone()
298327
{
299328
return new OnnxImage(_imageData);
@@ -413,7 +442,6 @@ private DenseTensor<float> NormalizeToOneToOne(ReadOnlySpan<int> dimensions)
413442
return imageArray;
414443
}
415444

416-
417445
/// <summary>
418446
/// Denormalizes the pixels from 0 to 1 to 0-255
419447
/// </summary>
OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Model;
5+
using OnnxStack.StableDiffusion.Common;
6+
using OnnxStack.StableDiffusion.Config;
7+
using OnnxStack.StableDiffusion.Enums;
8+
using OnnxStack.StableDiffusion.Models;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
using System.Threading.Tasks;
12+
13+
namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
14+
{
15+
public sealed class ImageDiffuser : StableCascadeDiffuser
{

    /// <summary>
    /// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
    /// All arguments are passed straight through to the <see cref="StableCascadeDiffuser"/> constructor.
    /// </summary>
    /// <param name="priorUnet">The prior unet.</param>
    /// <param name="decoderUnet">The decoder unet.</param>
    /// <param name="decoderVqgan">The decoder vqgan.</param>
    /// <param name="imageEncoder">The image encoder.</param>
    /// <param name="memoryMode">The memory mode.</param>
    /// <param name="logger">The logger.</param>
    public ImageDiffuser(UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel decoderVqgan, AutoEncoderModel imageEncoder, MemoryModeType memoryMode, ILogger logger = default)
        : base(priorUnet, decoderUnet, decoderVqgan, imageEncoder, memoryMode, logger) { }


    /// <summary>
    /// Gets the type of the diffuser.
    /// </summary>
    public override DiffuserType DiffuserType => DiffuserType.ImageToImage;


    /// <summary>
    /// Gets the timesteps.
    /// </summary>
    /// <param name="options">The options.</param>
    /// <param name="scheduler">The scheduler.</param>
    /// <returns>
    /// The timesteps from <paramref name="options"/> when they are not null or empty;
    /// otherwise the timesteps of <paramref name="scheduler"/>.
    /// </returns>
    protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
    {
        if (!options.Timesteps.IsNullOrEmpty())
        {
            return options.Timesteps;
        }

        return scheduler.Timesteps;
    }


    /// <summary>
    /// Encodes the prompt's input image into image embeddings.
    /// </summary>
    /// <param name="prompt">The prompt; its InputImage supplies the image to encode.</param>
    /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
    /// <returns>
    /// The encoder output as a tensor of dimensions [1, 1, ClipImageChannels]. When
    /// <paramref name="performGuidance"/> is set, a zero-initialized tensor of the same
    /// dimensions is concatenated in front of it.
    /// </returns>
    protected override async Task<DenseTensor<float>> EncodeImageAsync(PromptOptions prompt, bool performGuidance)
    {
        var metadata = await _vaeEncoder.GetMetadataAsync();
        var imageTensor = prompt.InputImage.GetClipImageFeatureTensor();
        using (var inferenceParameters = new OnnxInferenceParameters(metadata))
        {
            inferenceParameters.AddInputTensor(imageTensor);
            inferenceParameters.AddOutputBuffer(new[] { 1, ClipImageChannels });

            var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
            using (var result = results.First())
            {
                // Unload if required
                if (_memoryMode == MemoryModeType.Minimum)
                {
                    await _vaeEncoder.UnloadAsync();
                }

                var imageEmbeds = result.ToDenseTensor(new[] { 1, 1, ClipImageChannels });
                if (performGuidance)
                {
                    // Zero embeddings go first, followed by the encoded image.
                    // NOTE(review): presumably this yields a leading dimension of 2 - confirm
                    // against the Concatenate extension.
                    return new DenseTensor<float>(imageEmbeds.Dimensions).Concatenate(imageEmbeds);
                }

                return imageEmbeds;
            }
        }
    }
}
81+
}

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
6262

6363

6464
/// <summary>
65-
/// Prepares the decoder latents.
65+
/// Gets the clip image channels.
6666
/// </summary>
67-
/// <param name="prompt">The prompt.</param>
68-
/// <param name="options">The options.</param>
69-
/// <param name="scheduler">The scheduler.</param>
70-
/// <param name="timesteps">The timesteps.</param>
71-
/// <param name="priorLatents">The prior latents.</param>
72-
/// <returns></returns>
73-
protected abstract Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents);
67+
protected int ClipImageChannels => _clipImageChannels;
7468

7569

7670
/// <summary>
@@ -135,6 +129,8 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
135129
// Create latent sample
136130
var latents = await PrepareLatentsAsync(prompt, schedulerOptions, scheduler, timesteps);
137131

132+
var encodedImage = await EncodeImageAsync(prompt, performGuidance);
133+
138134
// Get Model metadata
139135
var metadata = await _unet.GetMetadataAsync();
140136

@@ -150,14 +146,13 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
150146
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
151147
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
152148
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
153-
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels });
154149
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
155150
{
156151
inferenceParameters.AddInputTensor(inputTensor);
157152
inferenceParameters.AddInputTensor(timestepTensor);
158153
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
159154
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
160-
inferenceParameters.AddInputTensor(imageEmbeds);
155+
inferenceParameters.AddInputTensor(encodedImage);
161156
inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);
162157

163158
var results = await _unet.RunInferenceAsync(inferenceParameters);
@@ -187,6 +182,8 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
187182
}
188183

189184

185+
186+
190187
/// <summary>
191188
/// Run the Decoder UNET diffusion
192189
/// </summary>
@@ -297,6 +294,59 @@ protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptio
297294
}
298295

299296

297+
/// <summary>
/// Creates the random starting latents.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options; Height and Width determine the latent size.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of dimensions [1, 16, height, width].</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    // Latent size is the requested size divided by ResolutionMultiple, rounded up.
    var latentHeight = (int)Math.Ceiling(options.Height / ResolutionMultiple);
    var latentWidth = (int)Math.Ceiling(options.Width / ResolutionMultiple);
    var dimensions = new[] { 1, 16, latentHeight, latentWidth };

    return Task.FromResult(scheduler.CreateRandomSample(dimensions, scheduler.InitNoiseSigma));
}
315+
316+
317+
/// <summary>
/// Creates the random starting latents for the decoder.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <param name="priorLatents">The prior latents; dimensions 2 and 3 determine the decoder latent size.</param>
/// <returns>A completed task holding a random sample of dimensions [1, 4, height, width].</returns>
protected virtual Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
{
    // Scale the prior latent height/width by LatentDimScale, truncating to int.
    var decoderHeight = (int)(priorLatents.Dimensions[2] * LatentDimScale);
    var decoderWidth = (int)(priorLatents.Dimensions[3] * LatentDimScale);
    var dimensions = new[] { 1, 4, decoderHeight, decoderWidth };

    return Task.FromResult(scheduler.CreateRandomSample(dimensions, scheduler.InitNoiseSigma));
}
336+
337+
338+
/// <summary>
/// Default image encoding: returns a newly allocated embeddings tensor without reading the prompt.
/// </summary>
/// <param name="prompt">The prompt (not used by this implementation).</param>
/// <param name="performGuidance">if set to <c>true</c> the first dimension is 2, otherwise 1.</param>
/// <returns>A completed task holding a new tensor of dimensions [1 or 2, 1, clip image channels].</returns>
protected virtual Task<DenseTensor<float>> EncodeImageAsync(PromptOptions prompt, bool performGuidance)
{
    var batchSize = performGuidance ? 2 : 1;
    var emptyEmbeds = new DenseTensor<float>(new[] { batchSize, 1, _clipImageChannels });
    return Task.FromResult(emptyEmbeds);
}
348+
349+
300350
/// <summary>
301351
/// Creates the timestep tensor.
302352
/// </summary>

OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,6 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
4848
}
4949

5050

51-
protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
52-
{
53-
var latents = scheduler.CreateRandomSample(new[]
54-
{
55-
1, 16,
56-
(int)Math.Ceiling(options.Height / ResolutionMultiple),
57-
(int)Math.Ceiling(options.Width / ResolutionMultiple)
58-
}, scheduler.InitNoiseSigma);
59-
return Task.FromResult(latents);
60-
}
61-
62-
63-
protected override Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
64-
{
65-
var latents = scheduler.CreateRandomSample(new[]
66-
{
67-
1, 4,
68-
(int)(priorLatents.Dimensions[2] * LatentDimScale),
69-
(int)(priorLatents.Dimensions[3] * LatentDimScale)
70-
}, scheduler.InitNoiseSigma);
71-
return Task.FromResult(latents);
72-
}
51+
7352
}
7453
}

OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
4040
_decoderUnet = decoderUnet;
4141
_supportedDiffusers = diffusers ?? new List<DiffuserType>
4242
{
43-
DiffuserType.TextToImage
43+
DiffuserType.TextToImage,
44+
DiffuserType.ImageToImage
4445
};
4546
_supportedSchedulers = new List<SchedulerType>
4647
{
@@ -99,6 +100,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
99100
return diffuserType switch
100101
{
101102
DiffuserType.TextToImage => new TextDiffuser(_unet, _decoderUnet, _vaeDecoder, _pipelineOptions.MemoryMode, _logger),
103+
DiffuserType.ImageToImage => new ImageDiffuser(_unet, _decoderUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
102104
_ => throw new NotImplementedException()
103105
};
104106
}

0 commit comments

Comments
 (0)