
Commit 3da3841

Support decoding of tensors (#416)
* Support decoding of tensors (Closes #362)
* Remove debug line
1 parent 768a2e2 commit 3da3841

2 files changed: +56 -2 lines changed

src/tokenizers.js

Lines changed: 32 additions & 2 deletions
@@ -122,6 +122,26 @@ function objectToMap(obj) {
     return new Map(Object.entries(obj));
 }
 
+/**
+ * Helper function to convert a tensor to a list before decoding.
+ * @param {Tensor} tensor The tensor to convert.
+ * @returns {number[]} The tensor as a list.
+ */
+function prepareTensorForDecode(tensor) {
+    const dims = tensor.dims;
+    switch (dims.length) {
+        case 1:
+            return tensor.tolist();
+        case 2:
+            if (dims[0] !== 1) {
+                throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
+            }
+            return tensor.tolist()[0];
+        default:
+            throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`)
+    }
+}
+
 /**
  * Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms
  * @param {string} text The text to clean up.
@@ -2556,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable {
 
     /**
      * Decode a batch of tokenized sequences.
-     * @param {number[][]} batch List of tokenized input sequences.
+     * @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences.
     * @param {Object} decode_args (Optional) Object with decoding arguments.
      * @returns {string[]} List of decoded sequences.
      */
     batch_decode(batch, decode_args = {}) {
+        if (batch instanceof Tensor) {
+            batch = batch.tolist();
+        }
         return batch.map(x => this.decode(x, decode_args));
     }
 
     /**
      * Decodes a sequence of token IDs back to a string.
      *
-     * @param {number[]} token_ids List of token IDs to decode.
+     * @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
      * @param {Object} [decode_args={}]
      * @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
      * @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
@@ -2579,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable {
         token_ids,
         decode_args = {},
     ) {
+        if (token_ids instanceof Tensor) {
+            token_ids = prepareTensorForDecode(token_ids);
+        }
+
         if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) {
             throw Error("token_ids must be a non-empty array of integers.");
         }
@@ -3458,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
         let text;
         // @ts-ignore
         if (decode_args && decode_args.decode_with_timestamps) {
+            if (token_ids instanceof Tensor) {
+                token_ids = prepareTensorForDecode(token_ids);
+            }
             text = this.decodeWithTimestamps(token_ids, decode_args);
         } else {
             text = super.decode(token_ids, decode_args);
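Taken together, these hunks let `decode` and `batch_decode` accept the Tensor returned when the tokenizer is called, instead of requiring a plain array of token IDs. A minimal usage sketch of the new behavior (the import path and model id below are assumptions mirroring the test added in this commit, not part of the diff):

import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

// Calling the tokenizer returns input_ids as a 2D Tensor of shape [1, seq_len].
const { input_ids } = tokenizer('hello world!');

// Previously the Tensor had to be converted to a list first;
// decode() now unwraps a 1D or 2D (batch size 1) Tensor itself.
const text = tokenizer.decode(input_ids, { skip_special_tokens: true });
// text === 'hello world!'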

tests/tokenizers.test.js

Lines changed: 24 additions & 0 deletions
@@ -57,3 +57,27 @@ describe('Edge cases', () => {
         compare(token_ids, [101, 100, 102])
     }, 5000); // NOTE: 5 seconds
 });
+
+describe('Extra decoding tests', () => {
+    it('should be able to decode the output of encode', async () => {
+        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
+
+        let text = 'hello world!';
+
+        // Ensure all the following outputs are the same:
+        // 1. Tensor of ids: allow decoding of 1D or 2D tensors.
+        let encodedTensor = tokenizer(text);
+        let decoded1 = tokenizer.decode(encodedTensor.input_ids, { skip_special_tokens: true });
+        let decoded2 = tokenizer.batch_decode(encodedTensor.input_ids, { skip_special_tokens: true })[0];
+        expect(decoded1).toEqual(text);
+        expect(decoded2).toEqual(text);
+
+        // 2. List of ids
+        let encodedList = tokenizer(text, { return_tensor: false });
+        let decoded3 = tokenizer.decode(encodedList.input_ids, { skip_special_tokens: true });
+        let decoded4 = tokenizer.batch_decode([encodedList.input_ids], { skip_special_tokens: true })[0];
+        expect(decoded3).toEqual(text);
+        expect(decoded4).toEqual(text);
+
+    }, MAX_TEST_EXECUTION_TIME);
+});
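The dims[0] !== 1 guard in prepareTensorForDecode means a 2D Tensor with batch size greater than 1 must go through batch_decode(), which converts the whole Tensor with tolist() and decodes each row, while decode() throws for it. A hedged sketch of that batched path (the array input and padding option are assumed from the tokenizer's call signature, not exercised by this commit's tests):

// Tokenizing multiple strings yields input_ids of shape [2, seq_len].
const batch = tokenizer(['hello world!', 'goodbye world!'], { padding: true });

// tokenizer.decode(batch.input_ids, ...) would throw:
//   Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.
const texts = tokenizer.batch_decode(batch.input_ids, { skip_special_tokens: true });
// texts is an array of two decoded strings.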
