
Commit 5da385d

Fix up Tokenizer/TextEncoder inputs

1 parent 9b3acf4

File tree: 2 files changed, +93 -58 lines

  OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
  OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
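
The hunks below pass around three small carrier types (TokenizerResult, EncoderResult, PromptEmbeddingsResult) whose definitions are not part of this commit. As orientation, here is a minimal sketch of their shape, inferred purely from how the diff uses them; the real OnnxStack definitions live in other files and may differ:

using Microsoft.ML.OnnxRuntime.Tensors;

// Inferred sketch only, not the actual OnnxStack types.
public class TokenizerResult
{
    public TokenizerResult(long[] inputIds, long[] attentionMask)
    {
        InputIds = inputIds;
        AttentionMask = attentionMask;
    }

    // Settable on purpose: the padding steps below reassign both arrays.
    public long[] InputIds { get; set; }
    public long[] AttentionMask { get; set; }
}

public class EncoderResult
{
    public EncoderResult(DenseTensor<float> promptEmbeds, DenseTensor<float> pooledPromptEmbeds)
    {
        PromptEmbeds = promptEmbeds;
        PooledPromptEmbeds = pooledPromptEmbeds;
    }

    public DenseTensor<float> PromptEmbeds { get; }
    public DenseTensor<float> PooledPromptEmbeds { get; }
}

public class PromptEmbeddingsResult
{
    public PromptEmbeddingsResult(DenseTensor<float> promptEmbeds, DenseTensor<float> pooledPromptEmbeds = null)
    {
        PromptEmbeds = promptEmbeds;
        PooledPromptEmbeds = pooledPromptEmbeds;
    }

    public DenseTensor<float> PromptEmbeds { get; }
    public DenseTensor<float> PooledPromptEmbeds { get; }
}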

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 50 additions & 27 deletions
@@ -498,7 +498,7 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -512,9 +512,11 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
     }

     if (isGuidanceEnabled)
-        return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
+        return new PromptEmbeddingsResult(
+            negativePromptEmbeddings.PromptEmbeds.Concatenate(promptEmbeddings.PromptEmbeds),
+            negativePromptEmbeddings.PooledPromptEmbeds.Concatenate(promptEmbeddings.PooledPromptEmbeds));

-    return new PromptEmbeddingsResult(promptEmbeddings);
+    return new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds, promptEmbeddings.PooledPromptEmbeds);
 }
@@ -523,22 +525,21 @@ protected virtual async Task<PromptEmbeddingsResult> CreatePromptEmbedsAsync(Pro
 /// </summary>
 /// <param name="inputText">The input text.</param>
 /// <returns></returns>
-protected async Task<int[]> DecodePromptTextAsync(string inputText)
+protected async Task<TokenizerResult> DecodePromptTextAsync(string inputText)
 {
     if (string.IsNullOrEmpty(inputText))
-        return Array.Empty<int>();
+        return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

     var metadata = await _tokenizer.GetMetadataAsync();
     var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer();
-
+        inferenceParameters.AddOutputBuffer();
         using (var results = _tokenizer.RunInference(inferenceParameters))
         {
-            var resultData = results.First().ToArray<long>();
-            return Array.ConvertAll(resultData, Convert.ToInt32);
+            return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
         }
     }
 }
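
DecodePromptTextAsync now registers a second output buffer because the ONNX tokenizer graph is expected to emit an attention_mask alongside input_ids. A hypothetical call from inside the pipeline class, with an illustrative prompt:

// Hypothetical usage; the prompt text is illustrative.
var tokens = await DecodePromptTextAsync("a photo of an astronaut riding a horse");

// The two arrays should come back the same length: one mask value per token id.
Console.WriteLine($"input_ids: {tokens.InputIds.Length}, attention_mask: {tokens.AttentionMask.Length}");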
@@ -549,21 +550,21 @@ protected async Task<int[]> DecodePromptTextAsync(string inputText)
 /// </summary>
 /// <param name="tokenizedInput">The tokenized input.</param>
 /// <returns></returns>
-protected async Task<float[]> EncodePromptTokensAsync(int[] tokenizedInput)
+protected async Task<EncoderResult> EncodePromptTokensAsync(TokenizerResult tokenizedInput)
 {
-    var inputDim = new[] { 1, tokenizedInput.Length };
-    var outputDim = new[] { 1, tokenizedInput.Length, _tokenizer.TokenizerLength };
     var metadata = await _textEncoder.GetMetadataAsync();
-    var inputTensor = new DenseTensor<int>(tokenizedInput, inputDim);
+    var inputTensor = new DenseTensor<int>(tokenizedInput.InputIds.ToInt(), new[] { 1, tokenizedInput.InputIds.Length });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         inferenceParameters.AddInputTensor(inputTensor);
-        inferenceParameters.AddOutputBuffer(outputDim);
+        inferenceParameters.AddOutputBuffer(new[] { 1, tokenizedInput.InputIds.Length, _tokenizer.TokenizerLength });
+        inferenceParameters.AddOutputBuffer(new int[] { 1, _tokenizer.TokenizerLength });

         var results = await _textEncoder.RunInferenceAsync(inferenceParameters);
-        using (var result = results.First())
+        using (var promptEmbeds = results.Last())
+        using (var promptEmbedsPooled = results.First())
         {
-            return result.ToArray();
+            return new EncoderResult(promptEmbeds.ToDenseTensor(), promptEmbedsPooled.ToDenseTensor());
         }
     }
 }
@@ -575,22 +576,44 @@ protected async Task<float[]> EncodePromptTokensAsync(int[] tokenizedInput)
 /// <param name="inputTokens">The input tokens.</param>
 /// <param name="minimumLength">The minimum length.</param>
 /// <returns></returns>
-protected async Task<DenseTensor<float>> GeneratePromptEmbedsAsync(int[] inputTokens, int minimumLength)
+protected async Task<PromptEmbeddingsResult> GeneratePromptEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
 {
     // If less than minimumLength pad with blank tokens
-    if (inputTokens.Length < minimumLength)
-        inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+    if (inputTokens.InputIds.Length < minimumLength)
+    {
+        inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+        inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+    }
+
+    // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
+    var tokenBatches = new List<long[]>();
+    var attentionBatches = new List<long[]>();
+    foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+        tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+    foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+        attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());

-    // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-    var embeddings = new List<float>();
-    foreach (var tokenBatch in inputTokens.Batch(_tokenizer.TokenizerLimit))
+    var promptEmbeddings = new List<float>();
+    var pooledPromptEmbeddings = new List<float>();
+    for (int i = 0; i < tokenBatches.Count; i++)
     {
-        var tokens = PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit);
-        embeddings.AddRange(await EncodePromptTokensAsync(tokens.ToArray()));
+        var result = await EncodePromptTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+        promptEmbeddings.AddRange(result.PromptEmbeds);
+        pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
     }

-    var dim = new[] { 1, embeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength };
-    return new DenseTensor<float>(embeddings.ToArray(), dim);
+    //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+    //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);
+
+    ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+    //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+    //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+
+    var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength });
+    var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, _tokenizer.TokenizerLimit, _tokenizer.TokenizerLength });
+    return new PromptEmbeddingsResult(promptTensor, pooledTensor);
 }
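
The rewritten loop relies on a Batch(int) extension method that is not part of this commit. A minimal sketch of the chunking behavior the code assumes (OnnxStack's real helper lives elsewhere and may differ):

using System.Collections.Generic;

public static class BatchExtensions
{
    // Split a sequence into consecutive chunks of at most `size` items.
    // The final chunk may be shorter, which is why each batch above is
    // re-padded to _tokenizer.TokenizerLimit after batching.
    public static IEnumerable<long[]> Batch(this IEnumerable<long> source, int size)
    {
        var chunk = new List<long>(size);
        foreach (var item in source)
        {
            chunk.Add(item);
            if (chunk.Count == size)
            {
                yield return chunk.ToArray();
                chunk.Clear();
            }
        }
        if (chunk.Count > 0)
            yield return chunk.ToArray();
    }
}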
@@ -600,11 +623,11 @@ protected async Task<DenseTensor<float>> GeneratePromptEmbedsAsync(int[] inputTo
 /// <param name="inputs">The inputs.</param>
 /// <param name="requiredLength">The required length of the returned array.</param>
 /// <returns></returns>
-protected IEnumerable<int> PadWithBlankTokens(IEnumerable<int> inputs, int requiredLength)
+protected IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
 {
     var count = inputs.Count();
     if (requiredLength > count)
-        return inputs.Concat(Enumerable.Repeat(_tokenizer.PadTokenId, requiredLength - count));
+        return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
     return inputs;
 }
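
A quick usage sketch of the reworked pad helper. The pad id 49407 is an assumption (CLIP's end-of-text token, which Stable Diffusion tokenizers commonly reuse for padding), and the token ids are made up:

long[] ids = { 320, 1125, 539 };                              // hypothetical token ids
var paddedIds = PadWithBlankTokens(ids, 77, 49407).ToArray(); // 77 ids; the last 74 are 49407
// The attention mask is padded with 1s, matching the call sites above.
var paddedMask = PadWithBlankTokens(new long[] { 1, 1, 1 }, 77, 1).ToArray();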

OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs

Lines changed: 43 additions & 31 deletions
@@ -140,7 +140,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(PromptOptions pr
     /// Tokenize Prompt and NegativePrompt with Tokenizer2
     var promptTokens = await DecodeTextAsLongAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodeTextAsLongAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -174,7 +174,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -188,8 +188,8 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     var dualPromptEmbeddings = await GenerateEmbedsAsync(dualPromptTokens, maxPromptTokenCount);
     var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(dualNegativePromptTokens, maxPromptTokenCount);

-    var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
-    var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
+    var dualPrompt = promptEmbeddings.PromptEmbeds.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
+    var dualNegativePrompt = negativePromptEmbeddings.PromptEmbeds.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
     var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
     var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
@@ -212,22 +212,21 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
 /// </summary>
 /// <param name="inputText">The input text.</param>
 /// <returns></returns>
-private async Task<long[]> DecodeTextAsLongAsync(string inputText)
+private async Task<TokenizerResult> DecodeTextAsLongAsync(string inputText)
 {
     if (string.IsNullOrEmpty(inputText))
-        return Array.Empty<long>();
+        return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

     var metadata = await _tokenizer2.GetMetadataAsync();
     var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer();
-
-        using (var results = _tokenizer2.RunInference(inferenceParameters))
+        inferenceParameters.AddOutputBuffer();
+        using (var results = _tokenizer.RunInference(inferenceParameters))
         {
-            var resultData = results.First().ToArray<long>();
-            return resultData;
+            return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
         }
     }
 }
@@ -238,16 +237,16 @@ private async Task<long[]> DecodeTextAsLongAsync(string inputText)
 /// </summary>
 /// <param name="tokenizedInput">The tokenized input.</param>
 /// <returns></returns>
-private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
+private async Task<EncoderResult> EncodeTokensAsync(TokenizerResult tokenizedInput)
 {
     var metadata = await _textEncoder2.GetMetadataAsync();
-    var inputTensor = new DenseTensor<long>(tokenizedInput, new[] { 1, tokenizedInput.Length });
+    var inputTensor = new DenseTensor<long>(tokenizedInput.InputIds, new[] { 1, tokenizedInput.InputIds.Length });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         int hiddenStateIndex = metadata.Outputs.Count - 2;
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer2.TokenizerLength });
-        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength });
+        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.InputIds.Length, _tokenizer2.TokenizerLength });

         var results = await _textEncoder2.RunInferenceAsync(inferenceParameters);
         var promptEmbeds = results.Last().ToDenseTensor();
@@ -263,30 +262,43 @@ private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
 /// <param name="inputTokens">The input tokens.</param>
 /// <param name="minimumLength">The minimum length.</param>
 /// <returns></returns>
-private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
+private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
 {
     // If less than minimumLength pad with blank tokens
-    if (inputTokens.Length < minimumLength)
-        inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+    if (inputTokens.InputIds.Length < minimumLength)
+    {
+        inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+        inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+    }

     // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-    var embeddings = new List<float>();
-    var pooledEmbeds = new List<float>();
-    foreach (var tokenBatch in inputTokens.Batch(_tokenizer2.TokenizerLimit))
-    {
-        var tokens = PadWithBlankTokens(tokenBatch, _tokenizer2.TokenizerLimit);
-        var result = await EncodeTokensAsync(tokens.ToArray());
+    var tokenBatches = new List<long[]>();
+    var attentionBatches = new List<long[]>();
+    foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+        tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+    foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+        attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());

-        embeddings.AddRange(result.PromptEmbeds);
-        pooledEmbeds.AddRange(result.PooledPromptEmbeds);
+    var promptEmbeddings = new List<float>();
+    var pooledPromptEmbeddings = new List<float>();
+    for (int i = 0; i < tokenBatches.Count; i++)
+    {
+        var result = await EncodeTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+        promptEmbeddings.AddRange(result.PromptEmbeds);
+        pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
     }

-    var embeddingsDim = new[] { 1, embeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
-    var promptTensor = new DenseTensor<float>(embeddings.ToArray(), embeddingsDim);
+    //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+    //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);

-    //TODO: Pooled embeds do not support more than 77 tokens, just grab first set
-    var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
-    var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+    ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+    //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+    //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+
+    var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength });
+    var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, pooledPromptEmbeddings.Count });
     return new PromptEmbeddingsResult(promptTensor, pooledTensor);
 }
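
For intuition on the new tensor shapes, a worked example. The numbers are assumptions: TokenizerLimit = 77 and _tokenizer2.TokenizerLength = 1280 (OpenCLIP ViT-bigG, the usual SDXL second encoder), with a prompt long enough to span two batches:

// Assumed: TokenizerLimit == 77, _tokenizer2.TokenizerLength == 1280,
// and a prompt that tokenizes into two padded batches of 77 tokens.
int tokenizerLength = 1280;
int batches = 2;
int promptFloats = batches * 77 * tokenizerLength; // 197,120 floats collected
int pooledFloats = batches * tokenizerLength;      // 2,560 floats collected
// promptTensor: [1, promptFloats / 1280, 1280] = [1, 154, 1280]
// pooledTensor: [1, 2560], one 1280-wide pooled vector per batch laid out flat
// (the base pipeline instead shapes its pooled output as [1, TokenizerLimit, TokenizerLength]).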

@@ -297,11 +309,11 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputToken
 /// <param name="inputs">The inputs.</param>
 /// <param name="requiredLength">The required length.</param>
 /// <returns></returns>
-private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength)
+private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
 {
     var count = inputs.Count();
     if (requiredLength > count)
-        return inputs.Concat(Enumerable.Repeat((long)_tokenizer.PadTokenId, requiredLength - count));
+        return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
     return inputs;
 }