Skip to content

Commit

Permalink
Split Prior and Decoder logic
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Apr 25, 2024
1 parent 8d5575a commit 9b3acf4
Showing 1 changed file with 53 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Image;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
Expand Down Expand Up @@ -53,40 +52,48 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
/// <returns></returns>
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
// Get Scheduler
using (var schedulerPrior = GetScheduler(schedulerOptions))
using (var schedulerDecoder = GetScheduler(schedulerOptions with{ InferenceSteps = 10, GuidanceScale = 0}))
{
//----------------------------------------------------
// Prior Unet
//====================================================
// Prior Unet
var latentsPrior = await DiffusePriorAsync(schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);

// Decoder Unet
var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 };
var latents = await DiffuseDecodeAsync(latentsPrior, schedulerOptionsDecoder, promptEmbeddings, performGuidance, progressCallback, cancellationToken);

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
}


protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timestepsPrior = GetTimesteps(schedulerOptions, schedulerPrior);
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample
var latentsPrior = schedulerPrior.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, schedulerPrior.InitNoiseSigma);
var latents = scheduler.CreateRandomSample(new[] { 1, 16, (int)Math.Ceiling(schedulerOptions.Height / 42.67f), (int)Math.Ceiling(schedulerOptions.Width / 42.67f) }, scheduler.InitNoiseSigma);

// Get Model metadata
var metadataPrior = await _unet.GetMetadataAsync();
var metadata = await _unet.GetMetadataAsync();

// Loop though the timesteps
var stepPrior = 0;
foreach (var timestep in timestepsPrior)
var step = 0;
foreach (var timestep in timesteps)
{
stepPrior++;
step++;
var stepTime = Stopwatch.GetTimestamp();
cancellationToken.ThrowIfCancellationRequested();

// Create input tensor.
var inputLatent = performGuidance ? latentsPrior.Repeat(2) : latentsPrior;
var inputTensor = schedulerPrior.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
var imageEmbeds = new DenseTensor<float>(performGuidance ? new[] { 2, 1, 768 } : new[] { 1, 1, 768 });
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
var imageEmbeds = new DenseTensor<float>(new[] { performGuidance ? 2 : 1, 1, 768 });

var outputChannels = performGuidance ? 2 : 1;
var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
using (var inferenceParameters = new OnnxInferenceParameters(metadataPrior))
var outputDimension = inputTensor.Dimensions.ToArray();
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
Expand All @@ -105,58 +112,57 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

// Scheduler Step
latentsPrior = schedulerPrior.Step(noisePred, timestep, latentsPrior).Result;
latents = scheduler.Step(noisePred, timestep, latents).Result;
}
}

ReportProgress(progressCallback, stepPrior, timestepsPrior.Count, latentsPrior);
_logger?.LogEnd(LogLevel.Debug, $"Step {stepPrior}/{timestepsPrior.Count}", stepTime);
ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd(LogLevel.Debug, $"Prior Step {step}/{timesteps.Count}", stepTime);
}

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();

return latents;
}
}




//----------------------------------------------------
// Decoder Unet
//====================================================

protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> latentsPrior, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timestepsDecoder = GetTimesteps(schedulerOptions, schedulerDecoder);
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample

var latentsDecoder = schedulerDecoder.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, schedulerDecoder.InitNoiseSigma);
var latents = scheduler.CreateRandomSample(new[] { 1, 4, (int)(latentsPrior.Dimensions[2] * 10.67f), (int)(latentsPrior.Dimensions[3] * 10.67f) }, scheduler.InitNoiseSigma);

// Get Model metadata
var metadataDecoder = await _decoderUnet.GetMetadataAsync();
var metadata = await _decoderUnet.GetMetadataAsync();

var effnet = performGuidance
? latentsPrior
: latentsPrior.Concatenate(new DenseTensor<float>(latentsPrior.Dimensions));


// Loop though the timesteps
var stepDecoder = 0;
foreach (var timestep in timestepsDecoder)
var step = 0;
foreach (var timestep in timesteps)
{
stepDecoder++;
step++;
var stepTime = Stopwatch.GetTimestamp();
cancellationToken.ThrowIfCancellationRequested();

// Create input tensor.
var inputLatent = performGuidance ? latentsDecoder.Repeat(2) : latentsDecoder;
var inputTensor = schedulerDecoder.ScaleInput(inputLatent, timestep);
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);


var outputChannels = performGuidance ? 2 : 1;
var outputDimension = inputTensor.Dimensions.ToArray(); //schedulerOptions.GetScaledDimension(outputChannels);
using (var inferenceParameters = new OnnxInferenceParameters(metadataDecoder))
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
Expand All @@ -174,20 +180,19 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

// Scheduler Step
latentsDecoder = schedulerDecoder.Step(noisePred, timestep, latentsDecoder).Result;
latents = scheduler.Step(noisePred, timestep, latents).Result;
}
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd(LogLevel.Debug, $"Decoder Step {step}/{timesteps.Count}", stepTime);
}

var testlatentsPrior = new OnnxImage(latentsPrior);
var testlatentsDecoder = new OnnxImage(latentsDecoder);
await testlatentsPrior.SaveAsync("D:\\testlatentsPrior.png");
await testlatentsDecoder.SaveAsync("D:\\latentsDecoder.png");

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();

// Decode Latents
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latentsDecoder);
return latents;
}
}

Expand Down

0 comments on commit 9b3acf4

Please sign in to comment.