Fix up Tokenizer/TextEncoder inputs

saddam213 committed Apr 25, 2024
1 parent 9b3acf4 commit 5da385d

Showing 2 changed files with 93 additions and 58 deletions.
77 changes: 50 additions & 27 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
@@ -498,7 +498,7 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
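Context for this change (not part of the commit): classifier-free guidance later concatenates the negative and positive embeddings, so both token sequences must first be padded to a common length, hence the shared maxPromptTokenCount. A minimal stand-alone sketch of that idea; the token values, the Pad helper, and the pad id 49407 (CLIP's usual end-of-text token) are assumptions for illustration:

```csharp
using System;
using System.Linq;

long[] promptIds = { 320, 1125, 539 };   // 3 tokens (hypothetical values)
long[] negativeIds = { 320 };            // 1 token
int maxCount = Math.Max(promptIds.Length, negativeIds.Length);

// Pad the shorter sequence up to the shared length.
long[] Pad(long[] ids, int length, long padToken) =>
    ids.Concat(Enumerable.Repeat(padToken, length - ids.Length)).ToArray();

negativeIds = Pad(negativeIds, maxCount, 49407);
Console.WriteLine($"{promptIds.Length} == {negativeIds.Length}"); // 3 == 3
```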
@@ -512,9 +512,11 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
     }

     if (isGuidanceEnabled)
-        return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
+        return new PromptEmbeddingsResult(
+            negativePromptEmbeddings.PromptEmbeds.Concatenate(promptEmbeddings.PromptEmbeds),
+            negativePromptEmbeddings.PooledPromptEmbeds.Concatenate(promptEmbeddings.PooledPromptEmbeds));

-    return new PromptEmbeddingsResult(promptEmbeddings);
+    return new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds, promptEmbeddings.PooledPromptEmbeds);
 }
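The two-argument constructor implies PromptEmbeddingsResult now carries pooled embeddings alongside the sequence embeddings. A plausible shape for the type, inferred from how this diff constructs and reads it rather than copied from the repository:

```csharp
using Microsoft.ML.OnnxRuntime.Tensors;

// Sketch only: shape inferred from the call sites in this commit.
public sealed class PromptEmbeddingsResult
{
    public PromptEmbeddingsResult(DenseTensor<float> promptEmbeds, DenseTensor<float> pooledPromptEmbeds = null)
    {
        PromptEmbeds = promptEmbeds;
        PooledPromptEmbeds = pooledPromptEmbeds;
    }

    public DenseTensor<float> PromptEmbeds { get; }        // [batch, tokens, hiddenSize]
    public DenseTensor<float> PooledPromptEmbeds { get; }  // pooled per-prompt embedding
}
```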


@@ -523,22 +525,21 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
     /// </summary>
     /// <param name="inputText">The input text.</param>
     /// <returns></returns>
-    protected async Task<int[]> DecodePromptTextAsync(string inputText)
+    protected async Task<TokenizerResult> DecodePromptTextAsync(string inputText)
     {
         if (string.IsNullOrEmpty(inputText))
-            return Array.Empty<int>();
+            return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

         var metadata = await _tokenizer.GetMetadataAsync();
         var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
         {
             inferenceParameters.AddInputTensor(inputTensor);
             inferenceParameters.AddOutputBuffer();
-
+            inferenceParameters.AddOutputBuffer();
             using (var results = _tokenizer.RunInference(inferenceParameters))
             {
-                var resultData = results.First().ToArray<long>();
-                return Array.ConvertAll(resultData, Convert.ToInt32);
+                return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
             }
         }
     }
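The tokenizer model is now expected to emit two outputs, token ids and an attention mask, hence the second AddOutputBuffer() call and the results[0]/results[1] indexing. A sketch of what TokenizerResult presumably holds; this is inferred from usage (the properties are settable because GeneratePromptEmbedsAsync reassigns both arrays when padding), not taken from the repository's source:

```csharp
// Sketch only: shape inferred from this diff's usage of the type.
public sealed class TokenizerResult
{
    public TokenizerResult(long[] inputIds, long[] attentionMask)
    {
        InputIds = inputIds;
        AttentionMask = attentionMask;
    }

    public long[] InputIds { get; set; }       // token ids from the tokenizer model
    public long[] AttentionMask { get; set; }  // attention mask aligned with InputIds
}
```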
@@ -549,21 +550,21 @@ protected async Task<int[]> DecodePromptTextAsync(string inputText)
     /// </summary>
     /// <param name="tokenizedInput">The tokenized input.</param>
     /// <returns></returns>
-    protected async Task<float[]> EncodePromptTokensAsync(int[] tokenizedInput)
+    protected async Task<EncoderResult> EncodePromptTokensAsync(TokenizerResult tokenizedInput)
     {
-        var inputDim = new[] { 1, tokenizedInput.Length };
-        var outputDim = new[] { 1, tokenizedInput.Length, _tokenizer.TokenizerLength };
         var metadata = await _textEncoder.GetMetadataAsync();
-        var inputTensor = new DenseTensor<int>(tokenizedInput, inputDim);
+        var inputTensor = new DenseTensor<int>(tokenizedInput.InputIds.ToInt(), new[] { 1, tokenizedInput.InputIds.Length });
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
         {
             inferenceParameters.AddInputTensor(inputTensor);
-            inferenceParameters.AddOutputBuffer(outputDim);
+            inferenceParameters.AddOutputBuffer(new[] { 1, tokenizedInput.InputIds.Length, _tokenizer.TokenizerLength });
+            inferenceParameters.AddOutputBuffer(new int[] { 1, _tokenizer.TokenizerLength });

             var results = await _textEncoder.RunInferenceAsync(inferenceParameters);
-            using (var result = results.First())
+            using (var promptEmbeds = results.Last())
+            using (var promptEmbedsPooled = results.First())
             {
-                return result.ToArray();
+                return new EncoderResult(promptEmbeds.ToDenseTensor(), promptEmbedsPooled.ToDenseTensor());
             }
         }
     }
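The encoder now returns two tensors, per-token hidden states sized [1, tokens, TokenizerLength] and a pooled output sized [1, TokenizerLength], wrapped in an EncoderResult. A sketch of that holder, again inferred from usage rather than copied from the repository:

```csharp
using Microsoft.ML.OnnxRuntime.Tensors;

// Sketch only: inferred from how EncodePromptTokensAsync builds the result.
public sealed class EncoderResult
{
    public EncoderResult(DenseTensor<float> promptEmbeds, DenseTensor<float> pooledPromptEmbeds)
    {
        PromptEmbeds = promptEmbeds;
        PooledPromptEmbeds = pooledPromptEmbeds;
    }

    public DenseTensor<float> PromptEmbeds { get; }        // [1, tokenCount, hiddenSize]
    public DenseTensor<float> PooledPromptEmbeds { get; }  // [1, hiddenSize]
}
```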
@@ -575,22 +576,44 @@ protected async Task<float[]> EncodePromptTokensAsync(int[] tokenizedInput)
     /// <param name="inputTokens">The input tokens.</param>
     /// <param name="minimumLength">The minimum length.</param>
     /// <returns></returns>
-    protected async Task<DenseTensor<float>> GeneratePromptEmbedsAsync(int[] inputTokens, int minimumLength)
+    protected async Task<PromptEmbeddingsResult> GeneratePromptEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
     {
         // If less than minimumLength pad with blank tokens
-        if (inputTokens.Length < minimumLength)
-            inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+        if (inputTokens.InputIds.Length < minimumLength)
+        {
+            inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+            inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+        }

         // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-        var embeddings = new List<float>();
-        foreach (var tokenBatch in inputTokens.Batch(_tokenizer.TokenizerLimit))
+        var tokenBatches = new List<long[]>();
+        var attentionBatches = new List<long[]>();
+        foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+            tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+        foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+            attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());
+
+        var promptEmbeddings = new List<float>();
+        var pooledPromptEmbeddings = new List<float>();
+        for (int i = 0; i < tokenBatches.Count; i++)
         {
-            var tokens = PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit);
-            embeddings.AddRange(await EncodePromptTokensAsync(tokens.ToArray()));
+            var result = await EncodePromptTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+            promptEmbeddings.AddRange(result.PromptEmbeds);
+            pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
         }

-        var dim = new[] { 1, embeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength };
-        return new DenseTensor<float>(embeddings.ToArray(), dim);
+        //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+        //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);
+
+        ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+        //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+        //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+
+        var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength });
+        var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, _tokenizer.TokenizerLimit, _tokenizer.TokenizerLength });
+        return new PromptEmbeddingsResult(promptTensor, pooledTensor);
     }
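The reworked loop splits long prompts into TokenizerLimit-sized windows (77 for CLIP), pads the final window, encodes each window separately, and concatenates the per-window embeddings. A self-contained sketch of the chunk-and-pad step; Batch here is a hypothetical stand-in for the repository's extension method, and 49407 is an assumed pad id:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static IEnumerable<long[]> Batch(long[] source, int size)
{
    // Yield consecutive windows of at most `size` tokens.
    for (int i = 0; i < source.Length; i += size)
        yield return source.Skip(i).Take(size).ToArray();
}

static long[] Pad(long[] ids, int length, long padToken) =>
    ids.Concat(Enumerable.Repeat(padToken, length - ids.Length)).ToArray();

// 100 tokens -> one full 77-token window plus a 23-token window padded to 77.
var tokens = Enumerable.Range(0, 100).Select(i => (long)i).ToArray();
var windows = Batch(tokens, 77).Select(w => Pad(w, 77, 49407)).ToList();
Console.WriteLine(string.Join(", ", windows.Select(w => w.Length))); // 77, 77
```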


@@ -600,11 +623,11 @@ protected async Task<DenseTensor<float>> GeneratePromptEmbedsAsync(int[] inputTo
     /// <param name="inputs">The inputs.</param>
     /// <param name="requiredLength">The required length of the returned array.</param>
     /// <returns></returns>
-    protected IEnumerable<int> PadWithBlankTokens(IEnumerable<int> inputs, int requiredLength)
+    protected IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
     {
         var count = inputs.Count();
         if (requiredLength > count)
-            return inputs.Concat(Enumerable.Repeat(_tokenizer.PadTokenId, requiredLength - count));
+            return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
         return inputs;
     }
74 changes: 43 additions & 31 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
@@ -140,7 +140,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(PromptOptions pr
     /// Tokenize Prompt and NegativePrompt with Tokenizer2
     var promptTokens = await DecodeTextAsLongAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodeTextAsLongAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -174,7 +174,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
    var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -188,8 +188,8 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     var dualPromptEmbeddings = await GenerateEmbedsAsync(dualPromptTokens, maxPromptTokenCount);
     var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(dualNegativePromptTokens, maxPromptTokenCount);

-    var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
-    var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
+    var dualPrompt = promptEmbeddings.PromptEmbeds.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
+    var dualNegativePrompt = negativePromptEmbeddings.PromptEmbeds.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
     var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
     var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
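For orientation: the axis-2 argument means the two encoders' embeddings are joined along the hidden dimension, not the token dimension. The sketch below mirrors that with plain arrays; the 768/1280 hidden sizes are the usual SDXL encoder widths, assumed here rather than read from this code:

```csharp
// Sketch: join [tokens, 768] and [tokens, 1280] along the hidden axis,
// mirroring Concatenate(..., 2) on [1, tokens, hidden] tensors.
const int tokens = 77, hidden1 = 768, hidden2 = 1280;
float[,] a = new float[tokens, hidden1];                 // from the first text encoder
float[,] b = new float[tokens, hidden2];                 // from the second text encoder
float[,] joined = new float[tokens, hidden1 + hidden2];
for (int t = 0; t < tokens; t++)
{
    for (int h = 0; h < hidden1; h++) joined[t, h] = a[t, h];
    for (int h = 0; h < hidden2; h++) joined[t, hidden1 + h] = b[t, h];
}
// joined corresponds to a [1, 77, 2048] prompt embedding for the SDXL UNet.
```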

@@ -212,22 +212,21 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     /// </summary>
     /// <param name="inputText">The input text.</param>
     /// <returns></returns>
-    private async Task<long[]> DecodeTextAsLongAsync(string inputText)
+    private async Task<TokenizerResult> DecodeTextAsLongAsync(string inputText)
     {
         if (string.IsNullOrEmpty(inputText))
-            return Array.Empty<long>();
+            return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

         var metadata = await _tokenizer2.GetMetadataAsync();
         var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
         {
             inferenceParameters.AddInputTensor(inputTensor);
             inferenceParameters.AddOutputBuffer();
-
-            using (var results = _tokenizer2.RunInference(inferenceParameters))
+            inferenceParameters.AddOutputBuffer();
+            using (var results = _tokenizer.RunInference(inferenceParameters))
             {
-                var resultData = results.First().ToArray<long>();
-                return resultData;
+                return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
             }
         }
     }
@@ -238,16 +237,16 @@ private async Task<long[]> DecodeTextAsLongAsync(string inputText)
     /// </summary>
     /// <param name="tokenizedInput">The tokenized input.</param>
     /// <returns></returns>
-    private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
+    private async Task<EncoderResult> EncodeTokensAsync(TokenizerResult tokenizedInput)
     {
         var metadata = await _textEncoder2.GetMetadataAsync();
-        var inputTensor = new DenseTensor<long>(tokenizedInput, new[] { 1, tokenizedInput.Length });
+        var inputTensor = new DenseTensor<long>(tokenizedInput.InputIds, new[] { 1, tokenizedInput.InputIds.Length });
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
         {
             int hiddenStateIndex = metadata.Outputs.Count - 2;
             inferenceParameters.AddInputTensor(inputTensor);
             inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer2.TokenizerLength });
-            inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength });
+            inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.InputIds.Length, _tokenizer2.TokenizerLength });

             var results = await _textEncoder2.RunInferenceAsync(inferenceParameters);
             var promptEmbeds = results.Last().ToDenseTensor();
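Note the hiddenStateIndex = metadata.Outputs.Count - 2 line: it binds the model's penultimate output as the hidden-state buffer. A tiny illustration of that indexing; the output names below are hypothetical, not read from the actual encoder:

```csharp
using System;

// Hypothetical output list for a text encoder that also exposes hidden states.
string[] outputs = { "text_embeds", "hidden_states.10", "hidden_states.11", "last_hidden_state" };
int hiddenStateIndex = outputs.Length - 2;     // -> 2, i.e. "hidden_states.11"
Console.WriteLine(outputs[hiddenStateIndex]);
```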
@@ -263,30 +262,43 @@ private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
     /// <param name="inputTokens">The input tokens.</param>
     /// <param name="minimumLength">The minimum length.</param>
     /// <returns></returns>
-    private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
+    private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
     {
         // If less than minimumLength pad with blank tokens
-        if (inputTokens.Length < minimumLength)
-            inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+        if (inputTokens.InputIds.Length < minimumLength)
+        {
+            inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+            inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+        }

         // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-        var embeddings = new List<float>();
-        var pooledEmbeds = new List<float>();
-        foreach (var tokenBatch in inputTokens.Batch(_tokenizer2.TokenizerLimit))
-        {
-            var tokens = PadWithBlankTokens(tokenBatch, _tokenizer2.TokenizerLimit);
-            var result = await EncodeTokensAsync(tokens.ToArray());
-            embeddings.AddRange(result.PromptEmbeds);
-            pooledEmbeds.AddRange(result.PooledPromptEmbeds);
-        }
+        var tokenBatches = new List<long[]>();
+        var attentionBatches = new List<long[]>();
+        foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+            tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+        foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+            attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());
+
+        var promptEmbeddings = new List<float>();
+        var pooledPromptEmbeddings = new List<float>();
+        for (int i = 0; i < tokenBatches.Count; i++)
+        {
+            var result = await EncodeTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+            promptEmbeddings.AddRange(result.PromptEmbeds);
+            pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
+        }

-        var embeddingsDim = new[] { 1, embeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
-        var promptTensor = new DenseTensor<float>(embeddings.ToArray(), embeddingsDim);
+        //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+        //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);

-        //TODO: Pooled embeds do not support more than 77 tokens, just grab first set
-        var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
-        var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+        ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+        //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+        //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+
+        var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength });
+        var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, pooledPromptEmbeddings.Count });
         return new PromptEmbeddingsResult(promptTensor, pooledTensor);
     }
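Worth noting: unlike the commented-out variant, the active code shapes the pooled tensor as [1, pooledPromptEmbeddings.Count], so its width grows with the number of 77-token windows. A small worked example under assumed dimensions (hidden size 1280 for the second SDXL encoder):

```csharp
using System;

// Assumed: second-encoder hidden size 1280, prompt spanning two 77-token windows.
int hiddenSize = 1280;
int windowCount = 2;
int pooledLength = windowCount * hiddenSize;                  // 2560 floats collected
Console.WriteLine($"active shape: [1, {pooledLength}]");      // [1, 2560]
Console.WriteLine($"commented variant: [1, {hiddenSize}]");   // [1, 1280], first window only
```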

@@ -297,11 +309,11 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputToken
     /// <param name="inputs">The inputs.</param>
     /// <param name="requiredLength">The required length of the returned array.</param>
     /// <returns></returns>
-    private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength)
+    private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
     {
         var count = inputs.Count();
         if (requiredLength > count)
-            return inputs.Concat(Enumerable.Repeat((long)_tokenizer.PadTokenId, requiredLength - count));
+            return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
         return inputs;
     }
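Usage note: with the new padTokenId parameter the one helper now serves both streams, padding token ids with the tokenizer's pad token and attention masks with 1, as the call sites above show. A self-contained sketch; the stand-alone copy of the helper and the pad id 49407 are assumptions for illustration:

```csharp
using System.Collections.Generic;
using System.Linq;

// Stand-alone equivalent of PadWithBlankTokens, for illustration only.
static IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
{
    var count = inputs.Count();
    return requiredLength > count
        ? inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count))
        : inputs;
}

var ids  = PadWithBlankTokens(new long[] { 320, 1125 }, 77, 49407).ToArray(); // ids padded with pad token (assumed 49407)
var mask = PadWithBlankTokens(new long[] { 1, 1 },      77, 1).ToArray();     // mask padded with 1, per the call sites
```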

