Skip to content

Support decoding of tensors (Closes #362) #416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ function objectToMap(obj) {
return new Map(Object.entries(obj));
}

/**
 * Helper function to convert a tensor to a flat list of token IDs before decoding.
 * Accepts a 1D tensor, or a 2D tensor whose leading (batch) dimension is 1.
 * @param {Tensor} tensor The tensor to convert.
 * @returns {number[]} The tensor contents as a flat list.
 * @throws {Error} If the tensor is 2D with batch size > 1, or has more than 2 dimensions.
 */
function prepareTensorForDecode(tensor) {
    const rank = tensor.dims.length;

    // 1D: already a single sequence.
    if (rank === 1) {
        return tensor.tolist();
    }

    // 2D: only a single-item batch can be decoded here; larger batches
    // must go through `batch_decode`.
    if (rank === 2) {
        if (tensor.dims[0] !== 1) {
            throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
        }
        return tensor.tolist()[0];
    }

    throw new Error(`Expected tensor to have 1-2 dimensions, got ${rank}.`);
}

/**
* Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms
* @param {string} text The text to clean up.
Expand Down Expand Up @@ -2556,18 +2576,21 @@ export class PreTrainedTokenizer extends Callable {

/**
* Decode a batch of tokenized sequences.
* @param {number[][]} batch List of tokenized input sequences.
* @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences.
* @param {Object} decode_args (Optional) Object with decoding arguments.
* @returns {string[]} List of decoded sequences.
*/
batch_decode(batch, decode_args = {}) {
if (batch instanceof Tensor) {
batch = batch.tolist();
}
return batch.map(x => this.decode(x, decode_args));
}

/**
* Decodes a sequence of token IDs back to a string.
*
* @param {number[]} token_ids List of token IDs to decode.
* @param {number[]|Tensor} token_ids List/Tensor of token IDs to decode.
* @param {Object} [decode_args={}]
* @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string.
* @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed.
Expand All @@ -2579,6 +2602,10 @@ export class PreTrainedTokenizer extends Callable {
token_ids,
decode_args = {},
) {
if (token_ids instanceof Tensor) {
token_ids = prepareTensorForDecode(token_ids);
}

if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) {
throw Error("token_ids must be a non-empty array of integers.");
}
Expand Down Expand Up @@ -3458,6 +3485,9 @@ export class WhisperTokenizer extends PreTrainedTokenizer {
let text;
// @ts-ignore
if (decode_args && decode_args.decode_with_timestamps) {
if (token_ids instanceof Tensor) {
token_ids = prepareTensorForDecode(token_ids);
}
text = this.decodeWithTimestamps(token_ids, decode_args);
} else {
text = super.decode(token_ids, decode_args);
Expand Down
24 changes: 24 additions & 0 deletions tests/tokenizers.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,27 @@ describe('Edge cases', () => {
compare(token_ids, [101, 100, 102])
}, 5000); // NOTE: 5 seconds
});

describe('Extra decoding tests', () => {
    it('should be able to decode the output of encode', async () => {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

        const text = 'hello world!';

        // Ensure all the following outputs are the same:
        // 1. Tensor of ids: allow decoding of 1D or 2D tensors.
        const { input_ids: tensorIds } = tokenizer(text);
        expect(tokenizer.decode(tensorIds, { skip_special_tokens: true })).toEqual(text);
        expect(tokenizer.batch_decode(tensorIds, { skip_special_tokens: true })[0]).toEqual(text);

        // 2. List of ids
        const { input_ids: listIds } = tokenizer(text, { return_tensor: false });
        expect(tokenizer.decode(listIds, { skip_special_tokens: true })).toEqual(text);
        expect(tokenizer.batch_decode([listIds], { skip_special_tokens: true })[0]).toEqual(text);

    }, MAX_TEST_EXECUTION_TIME);
});