Skip to content

Commit

Permalink
ImageToImage
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Apr 30, 2024
1 parent 2ab0fce commit e3ce080
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 35 deletions.
32 changes: 30 additions & 2 deletions OnnxStack.Core/Image/OnnxImage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,30 @@ public DenseTensor<float> GetImageTensor(ImageNormalizeType normalizeType = Imag
}


/// <summary>
/// Gets the image as a normalized 1x3x224x224 feature tensor for an image encoder.
/// </summary>
/// <returns>A tensor of shape [1, 3, 224, 224] with the R, G and B planes in channels 0, 1 and 2.</returns>
public DenseTensor<float> GetClipImageFeatureTensor()
{
    const int featureSize = 224;

    // The resize is applied to the object returned by Clone(), not directly to this instance
    var image = Clone();
    image.Resize(featureSize, featureSize);

    var imageArray = new DenseTensor<float>(new[] { 1, 3, featureSize, featureSize });

    // NOTE(review): these look like the commonly used ImageNet mean/std values;
    // confirm they match the preprocessing expected by the target image encoder.
    var mean = new[] { 0.485f, 0.456f, 0.406f };
    var stddev = new[] { 0.229f, 0.224f, 0.225f };

    image.ProcessPixelRows(img =>
    {
        // Walk row by row so each row span is fetched once instead of once per pixel
        for (int y = 0; y < image.Height; y++)
        {
            var pixelSpan = img.GetRowSpan(y);
            for (int x = 0; x < image.Width; x++)
            {
                var pixel = pixelSpan[x];

                // Scale to 0-1, then standardize each channel: (value - mean) / stddev
                imageArray[0, 0, y, x] = ((pixel.R / 255f) - mean[0]) / stddev[0];
                imageArray[0, 1, y, x] = ((pixel.G / 255f) - mean[1]) / stddev[1];
                imageArray[0, 2, y, x] = ((pixel.B / 255f) - mean[2]) / stddev[2];
            }
        }
    });
    return imageArray;
}


/// <summary>
/// Gets the image as tensor.
/// </summary>
Expand Down Expand Up @@ -293,7 +317,12 @@ public void Resize(int height, int width, ImageResizeMode resizeMode = ImageResi
});
}


/// <summary>
/// Forwards the supplied pixel accessor action to the underlying image data.
/// </summary>
/// <param name="processPixels">The action to run against the image's pixel rows.</param>
public void ProcessPixelRows(PixelAccessorAction<Rgba32> processPixels)
    => _imageData.ProcessPixelRows(processPixels);


public OnnxImage Clone()
{
return new OnnxImage(_imageData);
Expand Down Expand Up @@ -413,7 +442,6 @@ private DenseTensor<float> NormalizeToOneToOne(ReadOnlySpan<int> dimensions)
return imageArray;
}


/// <summary>
/// Denormalizes the pixels from 0 to 1 to 0-255
/// </summary>
Expand Down
81 changes: 81 additions & 0 deletions OnnxStack.StableDiffusion/Diffusers/StableCascade/ImageDiffuser.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
{
public sealed class ImageDiffuser : StableCascadeDiffuser
{

    /// <summary>
    /// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
    /// </summary>
    /// <param name="priorUnet">The prior unet.</param>
    /// <param name="decoderUnet">The decoder unet.</param>
    /// <param name="decoderVqgan">The decoder vqgan.</param>
    /// <param name="imageEncoder">The image encoder.</param>
    /// <param name="memoryMode">The memory mode.</param>
    /// <param name="logger">The logger.</param>
    public ImageDiffuser(UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel decoderVqgan, AutoEncoderModel imageEncoder, MemoryModeType memoryMode, ILogger logger = default)
        : base(priorUnet, decoderUnet, decoderVqgan, imageEncoder, memoryMode, logger) { }


    /// <summary>
    /// Gets the type of the diffuser.
    /// </summary>
    public override DiffuserType DiffuserType => DiffuserType.ImageToImage;


    /// <summary>
    /// Gets the timesteps.
    /// </summary>
    /// <param name="options">The options.</param>
    /// <param name="scheduler">The scheduler.</param>
    /// <returns>The timesteps from the options when any are supplied; otherwise the scheduler's timesteps.</returns>
    protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
    {
        if (!options.Timesteps.IsNullOrEmpty())
            return options.Timesteps;

        return scheduler.Timesteps;
    }


    /// <summary>
    /// Encodes the prompt's input image into image embeddings.
    /// </summary>
    /// <param name="prompt">The prompt.</param>
    /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
    /// <returns>
    /// The image embeddings with shape [1, 1, ClipImageChannels]; when guidance is enabled a
    /// zero-filled tensor of the same shape is concatenated in front of them.
    /// </returns>
    protected override async Task<DenseTensor<float>> EncodeImageAsync(PromptOptions prompt, bool performGuidance)
    {
        // No input image supplied: defer to the base implementation rather than
        // failing with a NullReferenceException
        if (prompt.InputImage is null)
            return await base.EncodeImageAsync(prompt, performGuidance);

        var metadata = await _vaeEncoder.GetMetadataAsync();
        var imageTensor = prompt.InputImage.GetClipImageFeatureTensor();
        using (var inferenceParameters = new OnnxInferenceParameters(metadata))
        {
            inferenceParameters.AddInputTensor(imageTensor);
            inferenceParameters.AddOutputBuffer(new[] { 1, ClipImageChannels });

            var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
            using (var result = results.First())
            {
                // Unload if required
                if (_memoryMode == MemoryModeType.Minimum)
                    await _vaeEncoder.UnloadAsync();

                var imageEmbeds = result.ToDenseTensor(new[] { 1, 1, ClipImageChannels });

                // With guidance, prepend a zero-filled tensor of the same shape
                if (performGuidance)
                    return new DenseTensor<float>(imageEmbeds.Dimensions).Concatenate(imageEmbeds);

                return imageEmbeds;
            }
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de


/// <summary>
/// Prepares the decoder latents.
/// Gets the clip image channels.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <param name="priorLatents">The prior latents.</param>
/// <returns></returns>
protected abstract Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents);
protected int ClipImageChannels
{
    get { return _clipImageChannels; }
}


/// <summary>
Expand Down Expand Up @@ -135,6 +129,8 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
// Create latent sample
var latents = await PrepareLatentsAsync(prompt, schedulerOptions, scheduler, timesteps);

var encodedImage = await EncodeImageAsync(prompt, performGuidance);

// Get Model metadata
var metadata = await _unet.GetMetadataAsync();

Expand All @@ -150,14 +146,13 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, _clipImageChannels });
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
inferenceParameters.AddInputTensor(imageEmbeds);
inferenceParameters.AddInputTensor(encodedImage);
inferenceParameters.AddOutputBuffer(inputTensor.Dimensions);

var results = await _unet.RunInferenceAsync(inferenceParameters);
Expand Down Expand Up @@ -187,6 +182,8 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(PromptOptions prompt,
}




/// <summary>
/// Run the Decoder UNET diffusion
/// </summary>
Expand Down Expand Up @@ -297,6 +294,59 @@ protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptio
}


/// <summary>
/// Creates the initial random latent sample.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options; Height and Width determine the latent size.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of shape [1, 16, height, width].</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    // Latent size is the requested size divided by the resolution multiple, rounded up
    var latentHeight = (int)Math.Ceiling(options.Height / ResolutionMultiple);
    var latentWidth = (int)Math.Ceiling(options.Width / ResolutionMultiple);

    var shape = new[] { 1, 16, latentHeight, latentWidth };
    var sample = scheduler.CreateRandomSample(shape, scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the initial random latent sample for the decoder.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <param name="priorLatents">The prior latents; dimensions 2 and 3 determine the decoder latent size.</param>
/// <returns>A completed task holding a random sample of shape [1, 4, height, width].</returns>
protected virtual Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
{
    // Scale the prior latent's spatial dimensions, truncating to whole numbers
    var decoderHeight = (int)(priorLatents.Dimensions[2] * LatentDimScale);
    var decoderWidth = (int)(priorLatents.Dimensions[3] * LatentDimScale);

    var shape = new[] { 1, 4, decoderHeight, decoderWidth };
    var sample = scheduler.CreateRandomSample(shape, scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Produces the default image embeddings: a newly allocated tensor with no values assigned.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="performGuidance">if set to <c>true</c> the first dimension is 2, otherwise 1.</param>
/// <returns>A completed task holding a tensor of shape [batch, 1, clip image channels].</returns>
protected virtual Task<DenseTensor<float>> EncodeImageAsync(PromptOptions prompt, bool performGuidance)
{
    var batchSize = performGuidance ? 2 : 1;
    var emptyEmbeds = new DenseTensor<float>(new[] { batchSize, 1, _clipImageChannels });
    return Task.FromResult(emptyEmbeds);
}


/// <summary>
/// Creates the timestep tensor.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,6 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
}


/// <summary>
/// Creates the initial random latent sample.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options; Height and Width determine the latent size.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>A completed task holding a random sample of shape [1, 16, height, width].</returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
    // Latent size is the requested size divided by the resolution multiple, rounded up
    var latentHeight = (int)Math.Ceiling(options.Height / ResolutionMultiple);
    var latentWidth = (int)Math.Ceiling(options.Width / ResolutionMultiple);

    var shape = new[] { 1, 16, latentHeight, latentWidth };
    var sample = scheduler.CreateRandomSample(shape, scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}


/// <summary>
/// Creates the initial random latent sample for the decoder.
/// </summary>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler used to create the random sample.</param>
/// <param name="timesteps">The timesteps.</param>
/// <param name="priorLatents">The prior latents; dimensions 2 and 3 determine the decoder latent size.</param>
/// <returns>A completed task holding a random sample of shape [1, 4, height, width].</returns>
protected override Task<DenseTensor<float>> PrepareDecoderLatentsAsync(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps, DenseTensor<float> priorLatents)
{
    // Scale the prior latent's spatial dimensions, truncating to whole numbers
    var decoderHeight = (int)(priorLatents.Dimensions[2] * LatentDimScale);
    var decoderWidth = (int)(priorLatents.Dimensions[3] * LatentDimScale);

    var shape = new[] { 1, 4, decoderHeight, decoderWidth };
    var sample = scheduler.CreateRandomSample(shape, scheduler.InitNoiseSigma);
    return Task.FromResult(sample);
}

}
}
4 changes: 3 additions & 1 deletion OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
_decoderUnet = decoderUnet;
_supportedDiffusers = diffusers ?? new List<DiffuserType>
{
DiffuserType.TextToImage
DiffuserType.TextToImage,
DiffuserType.ImageToImage
};
_supportedSchedulers = new List<SchedulerType>
{
Expand Down Expand Up @@ -99,6 +100,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
return diffuserType switch
{
DiffuserType.TextToImage => new TextDiffuser(_unet, _decoderUnet, _vaeDecoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ImageToImage => new ImageDiffuser(_unet, _decoderUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
Expand Down

0 comments on commit e3ce080

Please sign in to comment.