/**
 * Converts a plain JavaScript object into a `Map`, one entry per own
 * enumerable string-keyed property.
 * @param {Object} obj The object to convert.
 * @returns {Map<string, any>} A map with the same key/value pairs as `obj`.
 */
function objectToMap(obj) {
    const entries = Object.entries(obj);
    return new Map(entries);
}
125
/**
 * Helper function to convert a tensor to a list before decoding.
 * Accepts a 1D tensor of token IDs, or a 2D tensor whose first
 * dimension (batch size) is exactly 1.
 * @param {Tensor} tensor The tensor to convert.
 * @returns {number[]} The tensor as a list.
 * @throws {Error} If the tensor is 2D with a batch size other than 1,
 *   or has more than 2 dimensions.
 */
function prepareTensorForDecode(tensor) {
    const dims = tensor.dims;
    switch (dims.length) {
        case 1:
            return tensor.tolist();
        case 2:
            if (dims[0] !== 1) {
                throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
            }
            // Unwrap the single batch row.
            return tensor.tolist()[0];
        default:
            // Fixed: terminate with an explicit semicolon instead of relying on ASI.
            throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`);
    }
}
125
145
/**
126
146
* Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms
127
147
* @param {string} text The text to clean up.
@@ -2556,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable {
2556
2576
2557
2577
/**
2558
2578
* Decode a batch of tokenized sequences.
2559
- * @param {number[][] } batch List of tokenized input sequences.
2579
+ * @param {number[][]|Tensor } batch List/Tensor of tokenized input sequences.
2560
2580
* @param {Object } decode_args (Optional) Object with decoding arguments.
2561
2581
* @returns {string[] } List of decoded sequences.
2562
2582
*/
2563
2583
batch_decode ( batch , decode_args = { } ) {
2584
+ if ( batch instanceof Tensor ) {
2585
+ batch = batch . tolist ( ) ;
2586
+ }
2564
2587
return batch . map ( x => this . decode ( x , decode_args ) ) ;
2565
2588
}
2566
2589
2567
2590
/**
2568
2591
* Decodes a sequence of token IDs back to a string.
2569
2592
*
2570
- * @param {number[] } token_ids List of token IDs to decode.
2593
+ * @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
2571
2594
* @param {Object } [decode_args={}]
2572
2595
* @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
2573
2596
* @param {boolean } [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
@@ -2579,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable {
2579
2602
token_ids ,
2580
2603
decode_args = { } ,
2581
2604
) {
2605
+ if ( token_ids instanceof Tensor ) {
2606
+ token_ids = prepareTensorForDecode ( token_ids ) ;
2607
+ }
2608
+
2582
2609
if ( ! Array . isArray ( token_ids ) || token_ids . length === 0 || ! isIntegralNumber ( token_ids [ 0 ] ) ) {
2583
2610
throw Error ( "token_ids must be a non-empty array of integers." ) ;
2584
2611
}
@@ -3458,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
3458
3485
let text ;
3459
3486
// @ts-ignore
3460
3487
if ( decode_args && decode_args . decode_with_timestamps ) {
3488
+ if ( token_ids instanceof Tensor ) {
3489
+ token_ids = prepareTensorForDecode ( token_ids ) ;
3490
+ }
3461
3491
text = this . decodeWithTimestamps ( token_ids , decode_args ) ;
3462
3492
} else {
3463
3493
text = super . decode ( token_ids , decode_args ) ;
0 commit comments