From e3ce0801e2e4f95b59bd4977f9f432e657d54a68 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Tue, 30 Apr 2024 14:31:19 +1200 Subject: [PATCH] ImageToImage --- OnnxStack.Core/Image/OnnxImage.cs | 32 +++++++- .../Diffusers/StableCascade/ImageDiffuser.cs | 81 +++++++++++++++++++ .../StableCascade/StableCascadeDiffuser.cs | 70 +++++++++++++--- .../Diffusers/StableCascade/TextDiffuser.cs | 23 +----- .../Pipelines/StableCascadePipeline.cs | 4 +- 5 files changed, 175 insertions(+), 35 deletions(-) create mode 100644 OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs diff --git a/OnnxStack.Core/Image/OnnxImage.cs b/OnnxStack.Core/Image/OnnxImage.cs index c3db705..6cf0890 100644 --- a/OnnxStack.Core/Image/OnnxImage.cs +++ b/OnnxStack.Core/Image/OnnxImage.cs @@ -230,6 +230,30 @@ public DenseTensor GetImageTensor(ImageNormalizeType normalizeType = Imag } + public DenseTensor GetClipImageFeatureTensor() + { + var image = Clone(); + image.Resize(224, 224); + var imageArray = new DenseTensor(new[] { 1, 3, 224, 224 }); + var mean = new[] { 0.485f, 0.456f, 0.406f }; + var stddev = new[] { 0.229f, 0.224f, 0.225f }; + image.ProcessPixelRows(img => + { + for (int x = 0; x < image.Width; x++) + { + for (int y = 0; y < image.Height; y++) + { + var pixelSpan = img.GetRowSpan(y); + imageArray[0, 0, y, x] = ((pixelSpan[x].R / 255f) - mean[0]) / stddev[0]; + imageArray[0, 1, y, x] = ((pixelSpan[x].G / 255f) - mean[1]) / stddev[1]; + imageArray[0, 2, y, x] = ((pixelSpan[x].B / 255f) - mean[2]) / stddev[2]; + } + } + }); + return imageArray; + } + + /// /// Gets the image as tensor. /// @@ -293,7 +317,12 @@ public void Resize(int height, int width, ImageResizeMode resizeMode = ImageResi }); } - + public void ProcessPixelRows(PixelAccessorAction processPixels) + { + _imageData.ProcessPixelRows(processPixels); + } + + public OnnxImage Clone() { return new OnnxImage(_imageData); @@ -413,7 +442,6 @@ private DenseTensor NormalizeToOneToOne(ReadOnlySpan dimensions) return imageArray; } - /// /// Denormalizes the pixels from 0 to 1 to 0-255 /// diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs new file mode 100644 index 0000000..22950d5 --- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs @@ -0,0 +1,81 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Model; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.StableCascade +{ + public sealed class ImageDiffuser : StableCascadeDiffuser + { + + /// + /// Initializes a new instance of the class. + /// + /// The unet. + /// The vae decoder. + /// The vae encoder. + /// The logger. + public ImageDiffuser(UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel decoderVqgan, AutoEncoderModel imageEncoder, MemoryModeType memoryMode, ILogger logger = default) + : base(priorUnet, decoderUnet, decoderVqgan, imageEncoder, memoryMode, logger) { } + + + /// + /// Gets the type of the diffuser. + /// + public override DiffuserType DiffuserType => DiffuserType.ImageToImage; + + + /// + /// Gets the timesteps. + /// + /// The options. + /// The scheduler. + /// + protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) + { + if (!options.Timesteps.IsNullOrEmpty()) + return options.Timesteps; + + return scheduler.Timesteps; + } + + + /// + /// Encodes the image. + /// + /// The prompt. + /// if set to true [perform guidance]. + /// + protected override async Task> EncodeImageAsync(PromptOptions prompt, bool performGuidance) + { + var metadata = await _vaeEncoder.GetMetadataAsync(); + var imageTensor = prompt.InputImage.GetClipImageFeatureTensor(); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(imageTensor); + inferenceParameters.AddOutputBuffer(new[] { 1, ClipImageChannels }); + + var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters); + using (var result = results.First()) + { + // Unload if required + if (_memoryMode == MemoryModeType.Minimum) + await _vaeEncoder.UnloadAsync(); + + var image_embeds = result.ToDenseTensor(new[] { 1, 1, ClipImageChannels }); + if (performGuidance) + return new DenseTensor(image_embeds.Dimensions).Concatenate(image_embeds); + + return image_embeds; + } + } + } + } +} diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs index 61bbe65..d2ac4e6 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs @@ -62,15 +62,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de /// - /// Prepares the decoder latents. + /// Gets the clip image channels. /// - /// The prompt. - /// The options. - /// The scheduler. - /// The timesteps. - /// The prior latents. - /// - protected abstract Task> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps, DenseTensor priorLatents); + protected int ClipImageChannels => _clipImageChannels; /// @@ -135,6 +129,8 @@ protected async Task> DiffusePriorAsync(PromptOptions prompt, // Create latent sample var latents = await PrepareLatentsAsync(prompt, schedulerOptions, scheduler, timesteps); + var encodedImage = await EncodeImageAsync(prompt, performGuidance); + // Get Model metadata var metadata = await _unet.GetMetadataAsync(); @@ -150,14 +146,13 @@ protected async Task> DiffusePriorAsync(PromptOptions prompt, var inputLatent = performGuidance ? latents.Repeat(2) : latents; var inputTensor = scheduler.ScaleInput(inputLatent, timestep); var timestepTensor = CreateTimestepTensor(inputLatent, timestep); - var imageEmbeds = new DenseTensor(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels }); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { inferenceParameters.AddInputTensor(inputTensor); inferenceParameters.AddInputTensor(timestepTensor); inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds); inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); - inferenceParameters.AddInputTensor(imageEmbeds); + inferenceParameters.AddInputTensor(encodedImage); inferenceParameters.AddOutputBuffer(inputTensor.Dimensions); var results = await _unet.RunInferenceAsync(inferenceParameters); @@ -187,6 +182,8 @@ protected async Task> DiffusePriorAsync(PromptOptions prompt, } + + /// /// Run the Decoder UNET diffusion /// @@ -297,6 +294,59 @@ protected override async Task> DecodeLatentsAsync(PromptOptio } + /// + /// Prepares the input latents. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// The timesteps. + /// + protected override Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + var latents = scheduler.CreateRandomSample(new[] + { + 1, 16, + (int)Math.Ceiling(options.Height / ResolutionMultiple), + (int)Math.Ceiling(options.Width / ResolutionMultiple) + }, scheduler.InitNoiseSigma); + return Task.FromResult(latents); + } + + + /// + /// Prepares the decoder latents. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// The timesteps. + /// The prior latents. + /// + protected virtual Task> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps, DenseTensor priorLatents) + { + var latents = scheduler.CreateRandomSample(new[] + { + 1, 4, + (int)(priorLatents.Dimensions[2] * LatentDimScale), + (int)(priorLatents.Dimensions[3] * LatentDimScale) + }, scheduler.InitNoiseSigma); + return Task.FromResult(latents); + } + + + /// + /// Encodes the image. + /// + /// The prompt. + /// if set to true [perform guidance]. + /// + protected virtual Task> EncodeImageAsync(PromptOptions prompt, bool performGuidance) + { + return Task.FromResult(new DenseTensor(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels })); + } + + /// /// Creates the timestep tensor. /// diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs index 11cdbe4..73b1043 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/TextDiffuser.cs @@ -48,27 +48,6 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc } - protected override Task> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) - { - var latents = scheduler.CreateRandomSample(new[] - { - 1, 16, - (int)Math.Ceiling(options.Height / ResolutionMultiple), - (int)Math.Ceiling(options.Width / ResolutionMultiple) - }, scheduler.InitNoiseSigma); - return Task.FromResult(latents); - } - - - protected override Task> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps, DenseTensor priorLatents) - { - var latents = scheduler.CreateRandomSample(new[] - { - 1, 4, - (int)(priorLatents.Dimensions[2] * LatentDimScale), - (int)(priorLatents.Dimensions[3] * LatentDimScale) - }, scheduler.InitNoiseSigma); - return Task.FromResult(latents); - } + } } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs index a3a8cf7..004b90f 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs @@ -40,7 +40,8 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok _decoderUnet = decoderUnet; _supportedDiffusers = diffusers ?? new List { - DiffuserType.TextToImage + DiffuserType.TextToImage, + DiffuserType.ImageToImage }; _supportedSchedulers = new List { @@ -99,6 +100,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe return diffuserType switch { DiffuserType.TextToImage => new TextDiffuser(_unet, _decoderUnet, _vaeDecoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ImageToImage => new ImageDiffuser(_unet, _decoderUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; }