This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Add LatentConsistency XL pipeline #65

Merged
1 commit merged on Dec 10, 2023
@@ -0,0 +1,84 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using SixLabors.ImageSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
public sealed class ImageDiffuser : LatentConsistencyXLDiffuser
{
/// <summary>
/// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
/// </summary>
/// <param name="onnxModelService">The onnx model service.</param>
/// <param name="promptService">The prompt service.</param>
/// <param name="logger">The logger.</param>
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyXLDiffuser> logger)
: base(onnxModelService, promptService, logger)
{
}


/// <summary>
/// Gets the type of the diffuser.
/// </summary>
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;


/// <summary>
/// Gets the timesteps, narrowed by the image strength.
/// </summary>
/// <param name="options">The scheduler options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The timesteps to run for this image-to-image pass.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
// Image2Image: narrow the timestep range by the Strength
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
return scheduler.Timesteps.Skip(start).ToList();
}


/// <summary>
/// Prepares the latents for inference by encoding the input image with the VAE encoder and adding noise.
/// </summary>
/// <param name="model">The model set.</param>
/// <param name="prompt">The prompt options.</param>
/// <param name="options">The scheduler options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>The noised image latents.</returns>
protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });

//TODO: Model Config, Channels
var outputDimension = options.GetScaledDimension();
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(imageTensor);
inferenceParameters.AddOutputBuffer(outputDimension);

var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
using (var result = results.First())
{
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
}
}
}

}
}
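
A quick standalone sketch of the Strength-based narrowing in GetTimesteps above, using illustrative values only (not part of the pipeline): a higher Strength leaves more timesteps to run, so the input image is re-noised and reworked more heavily.

// Illustrative values only
int inferenceSteps = 30;
float strength = 0.4f;

// Same arithmetic as GetTimesteps
int initTimestep = Math.Min((int)(inferenceSteps * strength), inferenceSteps); // 12
int start = Math.Max(inferenceSteps - initTimestep, 0);                        // 18

// scheduler.Timesteps.Skip(start) then yields only the final 12 timesteps.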
@@ -0,0 +1,230 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Processing;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
public sealed class InpaintLegacyDiffuser : LatentConsistencyXLDiffuser
{
/// <summary>
/// Initializes a new instance of the <see cref="InpaintLegacyDiffuser"/> class.
/// </summary>
/// <param name="onnxModelService">The onnx model service.</param>
/// <param name="promptService">The prompt service.</param>
/// <param name="logger">The logger.</param>
public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyXLDiffuser> logger)
: base(onnxModelService, promptService, logger)
{
}


/// <summary>
/// Gets the type of the diffuser.
/// </summary>
public override DiffuserType DiffuserType => DiffuserType.ImageInpaintLegacy;


/// <summary>
/// Runs the scheduler steps.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample
var latentsOriginal = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

// Create masks sample
var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);

// Generate some noise
var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);

// Add noise to original latent
var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);

// Get Model metadata
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);

// Get Time ids
var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);

// Loop through the timesteps
var step = 0;
foreach (var timestep in timesteps)
{
step++;
var stepTime = Stopwatch.GetTimestamp();
cancellationToken.ThrowIfCancellationRequested();

// Create input tensor.
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(timestep);

var outputChannels = performGuidance ? 2 : 1;
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
inferenceParameters.AddInputTensor(addTimeIds);
inferenceParameters.AddOutputBuffer(outputDimension);

var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
using (var result = results.First())
{
var noisePred = result.ToDenseTensor();

// Perform guidance
if (performGuidance)
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

// Scheduler Step
var steplatents = scheduler.Step(noisePred, timestep, latents).Result;

// Add noise to original latent
var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });

// Apply mask and combine
latents = ApplyMaskedLatents(steplatents, initLatentsProper, maskImage);
}
}

progressCallback?.Invoke(step, timesteps.Count);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
}

// Decode Latents
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
}
}


/// <summary>
/// Gets the timesteps, narrowed by the image strength.
/// </summary>
/// <param name="options">The scheduler options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns>The timesteps to run for this inpaint pass.</returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
return scheduler.Timesteps.Skip(start).ToList();
}


/// <summary>
/// Prepares the latents for inference by encoding the input image with the VAE encoder.
/// </summary>
/// <param name="model">The model set.</param>
/// <param name="prompt">The prompt options.</param>
/// <param name="options">The scheduler options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns>The scaled image latents.</returns>
protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });

//TODO: Model Config, Channels
var outputDimensions = options.GetScaledDimension();
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(imageTensor);
inferenceParameters.AddOutputBuffer(outputDimensions);

var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
using (var result = results.First())
{
var outputResult = result.ToDenseTensor();
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
return scaledSample;
}
}
}


/// <summary>
/// Prepares the inpaint mask tensor.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <returns>The mask tensor.</returns>
private DenseTensor<float> PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
{
using (var mask = promptOptions.InputImageMask.ToImage())
{
// Prepare the mask
int width = schedulerOptions.GetScaledWidth();
int height = schedulerOptions.GetScaledHeight();
mask.Mutate(x => x.Grayscale());
mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));
var maskTensor = new DenseTensor<float>(new[] { 1, 4, height, width }); // NCHW, indexed as [0, channel, y, x] below
mask.ProcessPixelRows(img =>
{
for (int x = 0; x < width; x++)
{
for (int y = 0; y < height; y++)
{
var pixelSpan = img.GetRowSpan(y);
var value = 1f - (pixelSpan[x].A / 255.0f);
maskTensor[0, 0, y, x] = value;
maskTensor[0, 1, y, x] = value; // Needed for shape only
maskTensor[0, 2, y, x] = value; // Needed for shape only
maskTensor[0, 3, y, x] = value; // Needed for shape only
}
}
});
return maskTensor;
}
}


/// <summary>
/// Applies the masked latents.
/// </summary>
/// <param name="latents">The latents.</param>
/// <param name="initLatentsProper">The initialize latents proper.</param>
/// <param name="mask">The mask.</param>
/// <returns></returns>
private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
{
var result = new DenseTensor<float>(latents.Dimensions);
for (int i = 0; i < result.Length; i++)
{
float maskValue = mask.GetValue(i);
result.SetValue(i, initLatentsProper.GetValue(i) * maskValue + latents.GetValue(i) * (1f - maskValue));
}
return result;
}
}
}
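
The blend in ApplyMaskedLatents above is a per-element linear interpolation between the re-noised original latents and the freshly denoised latents, weighted by the mask value that PrepareMask derives as 1 - alpha/255. A standalone sketch with illustrative numbers only:

// Illustrative values only
float maskValue = 0.25f;      // per-element weight read from the mask tensor
float initLatent = 0.8f;      // re-noised original latent (initLatentsProper)
float denoisedLatent = -0.2f; // latent after the scheduler step

// Same formula as ApplyMaskedLatents
float blended = initLatent * maskValue + denoisedLatent * (1f - maskValue); // 0.2 + (-0.15) = 0.05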
@@ -0,0 +1,79 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
public abstract class LatentConsistencyXLDiffuser : StableDiffusionXLDiffuser
{
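/// <summary>
/// Initializes a new instance of the <see cref="LatentConsistencyXLDiffuser"/> class.
/// </summary>
/// <param name="onnxModelService">The onnx model service.</param>
/// <param name="promptService">The prompt service.</param>
/// <param name="logger">The logger.</param>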
protected LatentConsistencyXLDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
: base(onnxModelService, promptService, logger) { }


/// <summary>
/// Gets the type of the pipeline.
/// </summary>
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;


/// <summary>
/// Runs the stable diffusion loop, clearing the negative prompt as LCM does not support it.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The diffused result tensor.</returns>
public override Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
// LCM does not support negative prompting
promptOptions.NegativePrompt = string.Empty;
return base.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progressCallback, cancellationToken);
}


/// <summary>
/// Runs the stable diffusion batch loop
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="batchOptions">The batch options.</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
// LCM does not support negative prompting
promptOptions.NegativePrompt = string.Empty;
return base.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken);
}


/// <summary>
/// Gets the scheduler for the given options.
/// </summary>
/// <param name="options">The scheduler options.</param>
/// <returns>The scheduler, or <c>default</c> if the scheduler type is not supported.</returns>
protected override IScheduler GetScheduler(SchedulerOptions options)
{
return options.SchedulerType switch
{
SchedulerType.LCM => new LCMScheduler(options),
_ => default
};
}
}
}
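
For context, a minimal usage sketch of the new pipeline. This is hypothetical wiring: the services, the loaded StableDiffusionModelSet and the input image are assumed to come from the rest of OnnxStack (e.g. the host's dependency-injection container) and are not shown in this diff, and the Prompt property name is an assumption.

// Hypothetical wiring: onnxModelService, promptService, logger, modelSet and inputImage
// are assumed to be created/loaded elsewhere and are not part of this change.
var diffuser = new ImageDiffuser(onnxModelService, promptService, logger);

var promptOptions = new PromptOptions
{
    Prompt = "a photo of an astronaut riding a horse", // assumed property name
    InputImage = inputImage                            // source image for image-to-image
};

var schedulerOptions = new SchedulerOptions
{
    SchedulerType = SchedulerType.LCM, // the only type mapped by GetScheduler above
    InferenceSteps = 8,                // LCM runs with few steps; illustrative value
    Strength = 0.6f                    // how strongly the input image is re-noised (see GetTimesteps)
};

// The negative prompt is cleared internally, since LCM does not support negative prompting.
var resultTensor = await diffuser.DiffuseAsync(modelSet, promptOptions, schedulerOptions);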