@@ -140,7 +140,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(PromptOptions pr
     // Tokenize Prompt and NegativePrompt with Tokenizer2
     var promptTokens = await DecodeTextAsLongAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodeTextAsLongAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -174,7 +174,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
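Note: the `TokenizerResult` type referenced throughout this diff is not shown here. Judging only from how it is used below (a two-argument constructor and settable `InputIds`/`AttentionMask` properties), a minimal shape consistent with these call sites might look like the following sketch; the real OnnxStack type may differ.

    // Hypothetical sketch inferred from usage in this diff; not the actual OnnxStack definition.
    public class TokenizerResult
    {
        public TokenizerResult(long[] inputIds, long[] attentionMask)
        {
            InputIds = inputIds;
            AttentionMask = attentionMask;
        }

        public long[] InputIds { get; set; }       // token ids produced by the tokenizer
        public long[] AttentionMask { get; set; }  // one entry per token id
    }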
@@ -188,8 +188,8 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     var dualPromptEmbeddings = await GenerateEmbedsAsync(dualPromptTokens, maxPromptTokenCount);
     var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(dualNegativePromptTokens, maxPromptTokenCount);

-    var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
-    var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
+    var dualPrompt = promptEmbeddings.PromptEmbeds.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
+    var dualNegativePrompt = negativePromptEmbeddings.PromptEmbeds.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
     var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
     var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;

@@ -212,22 +212,21 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
 /// </summary>
 /// <param name="inputText">The input text.</param>
 /// <returns></returns>
-private async Task<long[]> DecodeTextAsLongAsync(string inputText)
+private async Task<TokenizerResult> DecodeTextAsLongAsync(string inputText)
 {
     if (string.IsNullOrEmpty(inputText))
-        return Array.Empty<long>();
+        return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

     var metadata = await _tokenizer2.GetMetadataAsync();
     var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer();
-
-        using (var results = _tokenizer2.RunInference(inferenceParameters))
+        inferenceParameters.AddOutputBuffer();
+        using (var results = _tokenizer.RunInference(inferenceParameters))
         {
-            var resultData = results.First().ToArray<long>();
-            return resultData;
+            return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
         }
     }
 }
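With the second output buffer registered, the tokenizer run now produces two tensors, read above as `results[0]` (token ids) and `results[1]` (attention mask). A brief, illustrative call from elsewhere in the same class might look like:

    // Illustrative usage only; assumes the surrounding class context.
    var tokens = await DecodeTextAsLongAsync("a photo of an astronaut riding a horse");
    // tokens.InputIds and tokens.AttentionMask now have equal length.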
@@ -238,16 +237,16 @@ private async Task<long[]> DecodeTextAsLongAsync(string inputText)
 /// </summary>
 /// <param name="tokenizedInput">The tokenized input.</param>
 /// <returns></returns>
-private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
+private async Task<EncoderResult> EncodeTokensAsync(TokenizerResult tokenizedInput)
 {
     var metadata = await _textEncoder2.GetMetadataAsync();
-    var inputTensor = new DenseTensor<long>(tokenizedInput, new[] { 1, tokenizedInput.Length });
+    var inputTensor = new DenseTensor<long>(tokenizedInput.InputIds, new[] { 1, tokenizedInput.InputIds.Length });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         int hiddenStateIndex = metadata.Outputs.Count - 2;
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer2.TokenizerLength });
-        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength });
+        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.InputIds.Length, _tokenizer2.TokenizerLength });

         var results = await _textEncoder2.RunInferenceAsync(inferenceParameters);
         var promptEmbeds = results.Last().ToDenseTensor();
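For reference, the two buffers registered above are the pooled output and the hidden state selected by `hiddenStateIndex`. With assumed example values (77 tokens after padding and a `TokenizerLength` of 1280, typical for an SDXL second encoder), the shapes work out as:

    // Shape arithmetic only; 77 and 1280 are assumptions for illustration,
    // not values taken from this diff.
    int tokenCount = 77;                                   // tokenizedInput.InputIds.Length
    int hiddenSize = 1280;                                 // _tokenizer2.TokenizerLength
    var pooledShape = new[] { 1, hiddenSize };             // first output buffer
    var hiddenShape = new[] { 1, tokenCount, hiddenSize }; // hidden-state output buffer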
@@ -263,30 +262,43 @@ private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
 /// <param name="inputTokens">The input tokens.</param>
 /// <param name="minimumLength">The minimum length.</param>
 /// <returns></returns>
-private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
+private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
 {
     // If less than minimumLength pad with blank tokens
-    if (inputTokens.Length < minimumLength)
-        inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+    if (inputTokens.InputIds.Length < minimumLength)
+    {
+        inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+        inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+    }

     // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-    var embeddings = new List<float>();
-    var pooledEmbeds = new List<float>();
-    foreach (var tokenBatch in inputTokens.Batch(_tokenizer2.TokenizerLimit))
-    {
-        var tokens = PadWithBlankTokens(tokenBatch, _tokenizer2.TokenizerLimit);
-        var result = await EncodeTokensAsync(tokens.ToArray());
+    var tokenBatches = new List<long[]>();
+    var attentionBatches = new List<long[]>();
+    foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+        tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+    foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+        attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());
+
-        embeddings.AddRange(result.PromptEmbeds);
-        pooledEmbeds.AddRange(result.PooledPromptEmbeds);
+    var promptEmbeddings = new List<float>();
+    var pooledPromptEmbeddings = new List<float>();
+    for (int i = 0; i < tokenBatches.Count; i++)
+    {
+        var result = await EncodeTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+        promptEmbeddings.AddRange(result.PromptEmbeds);
+        pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
     }
+
+
+    //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+    //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);

-    var embeddingsDim = new[] { 1, embeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
-    var promptTensor = new DenseTensor<float>(embeddings.ToArray(), embeddingsDim);
+    ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+    //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+    //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);

-    //TODO: Pooled embeds do not support more than 77 tokens, just grab first set
-    var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
-    var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+    var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength });
+    var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, pooledPromptEmbeddings.Count });
     return new PromptEmbeddingsResult(promptTensor, pooledTensor);
 }

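The new batching logic splits long prompts into tokenizer-limit sized chunks and pads the final chunk, with ids padded using the pad token and the attention mask padded with 1. A self-contained sketch of that flow; the limit of 77 and pad id of 49407 are assumptions for illustration, and .NET's `Chunk` stands in for the `Batch` extension used in the diff:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class BatchPadDemo
    {
        // Mirrors the padding helper in this diff: extend with padTokenId up to requiredLength.
        static IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
        {
            var count = inputs.Count();
            return requiredLength > count
                ? inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count))
                : inputs;
        }

        static void Main()
        {
            const int limit = 77;         // CLIP token limit (assumed)
            const int padTokenId = 49407; // illustrative pad id, model-dependent

            var inputIds = Enumerable.Range(0, 100).Select(i => (long)i).ToArray();
            var attentionMask = Enumerable.Repeat(1L, 100).ToArray();

            // Chunk into groups of `limit`, padding ids with the pad token
            // and the mask with 1, as GenerateEmbedsAsync does above.
            var tokenBatches = inputIds.Chunk(limit)
                .Select(b => PadWithBlankTokens(b, limit, padTokenId).ToArray()).ToList();
            var maskBatches = attentionMask.Chunk(limit)
                .Select(b => PadWithBlankTokens(b, limit, 1).ToArray()).ToList();

            Console.WriteLine(tokenBatches.Count);     // 2
            Console.WriteLine(tokenBatches[1].Length); // 77 (23 real tokens + 54 pad)
        }
    }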
@@ -297,11 +309,11 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputToken
 /// <param name="inputs">The inputs.</param>
 /// <param name="requiredLength">Length of the required.</param>
 /// <returns></returns>
-private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength)
+private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
 {
     var count = inputs.Count();
     if (requiredLength > count)
-        return inputs.Concat(Enumerable.Repeat((long)_tokenizer.PadTokenId, requiredLength - count));
+        return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
     return inputs;
 }

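Making the pad value a parameter is what lets the call sites above pad token ids with `_tokenizer.PadTokenId` while padding attention masks with `1`. With illustrative values (49407 is a common CLIP pad/end-of-text id):

    // Token ids padded with the tokenizer's pad token:
    var paddedIds = PadWithBlankTokens(new long[] { 320, 1125 }, 5, 49407).ToArray();
    // -> [320, 1125, 49407, 49407, 49407]

    // Attention mask padded with 1, so padded positions stay attended:
    var paddedMask = PadWithBlankTokens(new long[] { 1, 1 }, 5, 1).ToArray();
    // -> [1, 1, 1, 1, 1]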