Add support for tokenizer AttentionMask
saddam213 committed Apr 25, 2024
1 parent b6920e0 commit 995c9eb
Showing 6 changed files with 131 additions and 50 deletions.
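The change threads an attention mask from the tokenizer through to the text encoder: the tokenizer now returns a TokenizerResult carrying both the input ids and the mask, and EncodeTokensAsync feeds both tensors to the encoder. A minimal sketch of the new flow as it would sit inside the pipeline (identifiers taken from the diff below; the prompt text and surrounding wiring are placeholders):

// Sketch only - not code from this commit. Shows the intended call flow inside
// the Stable Cascade pipeline after this change; batching and padding omitted.
TokenizerResult tokens = await DecodeTextAsLongAsync("photo of a cat");

// The tokenizer output now pairs ids with a same-length attention mask.
long[] inputIds      = tokens.InputIds;
long[] attentionMask = tokens.AttentionMask;

// Both tensors are passed to the text encoder, which returns the hidden-state
// prompt embeddings and the pooled embeddings as DenseTensor<float>.
EncoderResult embeds = await EncodeTokensAsync(tokens);
DenseTensor<float> promptEmbeds       = embeds.PromptEmbeds;
DenseTensor<float> pooledPromptEmbeds = embeds.PooledPromptEmbeds;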
16 changes: 9 additions & 7 deletions OnnxStack.Console/Examples/StableCascadeExample.cs
@@ -14,7 +14,7 @@ public sealed class StableCascadeExample : IExampleRunner
public StableCascadeExample(StableDiffusionConfig configuration)
{
_configuration = configuration;
_outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(StableDiffusionExample));
_outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(StableCascadeExample));
}

public int Index => 20;
@@ -31,14 +31,14 @@ public async Task RunAsync()
Directory.CreateDirectory(_outputDirectory);


var prompt = "cat wearing a hat";
var prompt = "photo of a cat";
var promptOptions = new PromptOptions
{
Prompt = prompt
};

// Create Pipeline
var pipeline = StableCascadePipeline.CreatePipeline("D:\\Repositories\\stable-cascade-onnx\\unoptimized", memoryMode: StableDiffusion.Enums.MemoryModeType.Minimum);
var pipeline = StableCascadePipeline.CreatePipeline("D:\\Repositories\\stable-cascade-onnx", memoryMode: StableDiffusion.Enums.MemoryModeType.Minimum);

// Preload Models (optional)
await pipeline.LoadAsync();
@@ -48,8 +48,8 @@ public async Task RunAsync()
var schedulerOptions = pipeline.DefaultSchedulerOptions with
{
SchedulerType = StableDiffusion.Enums.SchedulerType.DDPM,
GuidanceScale = 5f,
InferenceSteps = 10,
GuidanceScale = 4f,
InferenceSteps = 60,
Width = 1024,
Height = 1024
};
@@ -58,10 +58,12 @@ public async Task RunAsync()


// Run pipeline
var result = await pipeline.GenerateImageAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);

var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);

// Save Image File
await result.SaveAsync(Path.Combine(_outputDirectory, $"output.png"));
await image.SaveAsync(Path.Combine(_outputDirectory, $"output.png"));

await pipeline.UnloadAsync();

14 changes: 12 additions & 2 deletions OnnxStack.StableDiffusion/Common/PromptResult.cs
@@ -4,7 +4,17 @@ namespace OnnxStack.StableDiffusion.Common
{
public record PromptEmbeddingsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> PooledPromptEmbeds = default);

public record EncoderResult(float[] PromptEmbeds, float[] PooledPromptEmbeds);
public record EncoderResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> PooledPromptEmbeds);

public record EmbedsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> PooledPromptEmbeds);
public record TokenizerResult
{
public TokenizerResult(long[] inputIds, long[] attentionMask)
{
InputIds = inputIds;
AttentionMask = attentionMask;
}

public long[] InputIds { get; set; }
public long[] AttentionMask { get; set; }
}
}
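The new TokenizerResult is a plain mutable pair of ids and mask, and EncoderResult now carries DenseTensor<float> values so the tensor shapes survive past the encoder call. A hypothetical construction, purely for illustration (the values are invented, not real tokenizer output):

// Illustration only - invented values, not produced by the real tokenizer.
var tokens = new TokenizerResult(
    inputIds:      new long[] { 101, 2345, 678, 90 },
    attentionMask: new long[] { 1, 1, 1, 1 });

// The two arrays are kept the same length so the encoder can align the mask
// with the ids element by element.
System.Diagnostics.Debug.Assert(tokens.InputIds.Length == tokens.AttentionMask.Length);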
@@ -55,7 +55,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
{
// Get Scheduler
using (var schedulerPrior = GetScheduler(schedulerOptions))
using (var schedulerDecoder = GetScheduler(schedulerOptions))
using (var schedulerDecoder = GetScheduler(schedulerOptions with { InferenceSteps = 10, GuidanceScale = 0 }))
{
//----------------------------------------------------
// Prior Unet
@@ -1,5 +1,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
@@ -39,6 +40,9 @@ public TextDiffuser(UNetConditionModel priorUnet, UNetConditionModel decoderUnet
/// <returns></returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
if (!options.Timesteps.IsNullOrEmpty())
return options.Timesteps;

return scheduler.Timesteps;
}

128 changes: 97 additions & 31 deletions OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs
@@ -44,13 +44,13 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
};
_supportedSchedulers = new List<SchedulerType>
{
SchedulerType.EulerAncestral
SchedulerType.DDPM
};
_defaultSchedulerOptions = defaultSchedulerOptions ?? new SchedulerOptions
{
InferenceSteps = 1,
GuidanceScale = 0f,
SchedulerType = SchedulerType.EulerAncestral
SchedulerType = SchedulerType.DDPM
};
}

@@ -76,12 +76,19 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
};
}


/// <summary>
/// Creates the embeds using Tokenizer2 and TextEncoder2
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="isGuidanceEnabled">if set to <c>true</c> [is guidance enabled].</param>
/// <returns></returns>
protected override async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(PromptOptions promptOptions, bool isGuidanceEnabled)
{
/// Tokenize Prompt and NegativePrompt with Tokenizer2
var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
var promptTokens = await DecodeTextAsLongAsync(promptOptions.Prompt);
var negativePromptTokens = await DecodeTextAsLongAsync(promptOptions.NegativePrompt);
var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

// Generate embeds for tokens
var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -103,50 +110,109 @@ protected override async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pr
}


private async Task<EncoderResult> EncodeTokensAsync(int[] tokenizedInput)
/// <summary>
/// Decodes the text as tokens
/// </summary>
/// <param name="inputText">The input text.</param>
/// <returns></returns>
private async Task<TokenizerResult> DecodeTextAsLongAsync(string inputText)
{
if (string.IsNullOrEmpty(inputText))
return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

var metadata = await _tokenizer.GetMetadataAsync();
var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddOutputBuffer();
inferenceParameters.AddOutputBuffer();

using (var results = _tokenizer.RunInference(inferenceParameters))
{
return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
}
}
}


/// <summary>
/// Encodes the tokens.
/// </summary>
/// <param name="tokenizedInput">The tokenized input.</param>
/// <returns></returns>
private async Task<EncoderResult> EncodeTokensAsync(TokenizerResult tokenizedInput)
{
var inputDim = new[] { 1, tokenizedInput.Length };
var promptOutputDim = new[] { 1, tokenizedInput.Length, _tokenizer.TokenizerLength };
var pooledOutputDim = new[] { 1, _tokenizer.TokenizerLength };
var metadata = await _textEncoder.GetMetadataAsync();
var inputTensor = new DenseTensor<int>(tokenizedInput, inputDim);
var inputTensor = new DenseTensor<long>(tokenizedInput.InputIds, new[] { 1, tokenizedInput.InputIds.Length });
var attentionTensor = new DenseTensor<long>(tokenizedInput.AttentionMask, new[] { 1, tokenizedInput.AttentionMask.Length });
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddOutputBuffer(pooledOutputDim);
inferenceParameters.AddOutputBuffer(promptOutputDim);
inferenceParameters.AddInputTensor(attentionTensor);

// text_embeds + hidden_states.32
inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer.TokenizerLength });
inferenceParameters.AddOutputBuffer(metadata.Outputs.Count - 1, new[] { 1, tokenizedInput.InputIds.Length, _tokenizer.TokenizerLength });

var results = await _textEncoder.RunInferenceAsync(inferenceParameters);
return new EncoderResult(results.Last().ToArray(), results.First().ToArray());
var promptEmbeds = results.Last().ToDenseTensor();
var promptEmbedsPooled = results.First().ToDenseTensor();
return new EncoderResult(promptEmbeds, promptEmbedsPooled);
}
}


private async Task<EmbedsResult> GenerateEmbedsAsync(int[] inputTokens, int minimumLength)
/// <summary>
/// Generates the embeds.
/// </summary>
/// <param name="inputTokens">The input tokens.</param>
/// <param name="minimumLength">The minimum length.</param>
/// <returns></returns>
private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
{
// If less than minimumLength pad with blank tokens
if (inputTokens.Length < minimumLength)
inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();

// The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate1
var embeddings = new List<float>();
var pooledEmbeds = new List<float>();
foreach (var tokenBatch in inputTokens.Batch(_tokenizer.TokenizerLimit))
if (inputTokens.InputIds.Length < minimumLength)
{
var tokens = PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit);
var result = await EncodeTokensAsync(tokens.ToArray());
inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
}

embeddings.AddRange(result.PromptEmbeds);
pooledEmbeds.AddRange(result.PooledPromptEmbeds);
// The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
var tokenBatches = new List<long[]>();
var attentionBatches = new List<long[]>();
foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());

var promptEmbeddings = new List<float>();
var pooledPromptEmbeddings = new List<float>();
for (int i = 0; i < tokenBatches.Count; i++)
{
var result = await EncodeTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
promptEmbeddings.AddRange(result.PromptEmbeds);
pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
}

var embeddingsDim = new[] { 1, embeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength };
var promptTensor = new DenseTensor<float>(embeddings.ToArray(), embeddingsDim);
var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength });
var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, tokenBatches.Count, 1280 });
return new PromptEmbeddingsResult(promptTensor, pooledTensor);
}


//TODO: Pooled embeds do not support more than 77 tokens, just grab first set
var pooledDim = new[] { 1, 1, _tokenizer.TokenizerLength };
var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer.TokenizerLength).ToArray(), pooledDim);
return new EmbedsResult(promptTensor, pooledTensor);
/// <summary>
/// Pads the input array with blank tokens.
/// </summary>
/// <param name="inputs">The inputs.</param>
/// <param name="requiredLength">Length of the required.</param>
/// <returns></returns>
private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
{
var count = inputs.Count();
if (requiredLength > count)
return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
return inputs;
}


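The shapes coming out of the new GenerateEmbedsAsync are easy to misread, so a rough walk-through may help. Assuming TokenizerLimit is 77 and TokenizerLength is 1280 (the hard-coded pooled dimension above suggests 1280), a prompt of about 100 tokens flows as follows; this is a back-of-the-envelope illustration, not code from the repository:

// Hypothetical shape walk-through, assuming TokenizerLimit = 77 and
// TokenizerLength = 1280 (matching the hard-coded pooled dimension above).
int tokenizerLimit   = 77;
int tokenizerLength  = 1280;
int promptTokenCount = 100;                                  // example prompt length

// 100 tokens are split into chunks of 77, and each chunk is padded back up to 77.
int batchCount = (promptTokenCount + tokenizerLimit - 1) / tokenizerLimit;        // 2

// Each batch yields a [1, 77, 1280] hidden-state slice; concatenated along the
// token axis they form the prompt embeddings tensor.
int[] promptEmbedsShape = { 1, batchCount * tokenizerLimit, tokenizerLength };    // [1, 154, 1280]

// One pooled vector per batch is kept and stacked into the pooled tensor.
int[] pooledEmbedsShape = { 1, batchCount, 1280 };                                // [1, 2, 1280]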
17 changes: 8 additions & 9 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
@@ -240,20 +240,19 @@ private async Task<long[]> DecodeTextAsLongAsync(string inputText)
/// <returns></returns>
private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
{
var inputDim = new[] { 1, tokenizedInput.Length };
var promptOutputDim = new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength };
var pooledOutputDim = new[] { 1, _tokenizer2.TokenizerLength };
var metadata = await _textEncoder2.GetMetadataAsync();
var inputTensor = new DenseTensor<long>(tokenizedInput, inputDim);
var inputTensor = new DenseTensor<long>(tokenizedInput, new[] { 1, tokenizedInput.Length });
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
int hiddenStateIndex = metadata.Outputs.Count - 2;
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddOutputBuffer(pooledOutputDim);
inferenceParameters.AddOutputBuffer(hiddenStateIndex, promptOutputDim);
inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer2.TokenizerLength });
inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength });

var results = await _textEncoder2.RunInferenceAsync(inferenceParameters);
return new EncoderResult(results.Last().ToArray(), results.First().ToArray());
var promptEmbeds = results.Last().ToDenseTensor();
var promptEmbedsPooled = results.First().ToDenseTensor();
return new EncoderResult(promptEmbeds, promptEmbedsPooled);
}
}

@@ -264,7 +263,7 @@ private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
/// <param name="inputTokens">The input tokens.</param>
/// <param name="minimumLength">The minimum length.</param>
/// <returns></returns>
private async Task<EmbedsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
{
// If less than minimumLength pad with blank tokens
if (inputTokens.Length < minimumLength)
@@ -288,7 +287,7 @@ private async Task<EmbedsResult> GenerateEmbedsAsync(long[] inputTokens, int min
//TODO: Pooled embeds do not support more than 77 tokens, just grab first set
var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
return new EmbedsResult(promptTensor, pooledTensor);
return new PromptEmbeddingsResult(promptTensor, pooledTensor);
}


