
Commit 11fdae5

Add special tokens in text-generation pipeline if tokenizer requires (#1370)
* Add special tokens in text-generation pipeline if tokenizer requires
* Fix logits processors tests
* Update bundles.test.js
* Update comment
* Formatting
1 parent 2c32e1d commit 11fdae5

6 files changed: +83 additions, -44 deletions


src/pipelines.js

Lines changed: 6 additions & 3 deletions
@@ -996,6 +996,11 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
         let isBatched = false;
         let isChatInput = false;
 
+        // By default, do not add special tokens, unless the tokenizer specifies otherwise
+        let add_special_tokens = generate_kwargs.add_special_tokens
+            ?? (this.tokenizer.add_bos_token || this.tokenizer.add_eos_token)
+            ?? false;
+
         // Normalize inputs
         /** @type {string[]} */
         let inputs;
@@ -1021,11 +1026,9 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
                     add_generation_prompt: true,
                 })
             ));
+            add_special_tokens = false; // Chat template handles this already
         }
 
-        // By default, do not add special tokens
-        const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
-
         // By default, return full text
         const return_full_text = isChatInput
             ? false
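
To make the new precedence explicit: an `add_special_tokens` passed by the caller always wins, then the tokenizer's own config, then the old default of `false`. A minimal standalone sketch of that resolution order (the helper name is illustrative, not part of the library):

function resolveAddSpecialTokens(generate_kwargs, tokenizer) {
  // Caller override first, then tokenizer config, then the old default.
  return generate_kwargs.add_special_tokens
    ?? (tokenizer.add_bos_token || tokenizer.add_eos_token)
    ?? false;
}

console.log(resolveAddSpecialTokens({}, {}));                      // false
console.log(resolveAddSpecialTokens({}, { add_bos_token: true })); // true
console.log(resolveAddSpecialTokens({ add_special_tokens: false }, { add_bos_token: true })); // false

Note the operator choice: `||` lets either config flag switch the default on, while the trailing `?? false` covers tokenizers whose config defines neither flag (both are `undefined`, so the `||` expression is also `undefined` and falls through).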

src/tokenizers.js

Lines changed: 3 additions & 0 deletions
@@ -2659,6 +2659,9 @@ export class PreTrainedTokenizer extends Callable {
             this.padding_side = tokenizerConfig.padding_side;
         }
 
+        this.add_bos_token = tokenizerConfig.add_bos_token;
+        this.add_eos_token = tokenizerConfig.add_eos_token;
+
         this.legacy = false;
 
         this.chat_template = tokenizerConfig.chat_template ?? null;
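
These two fields are read straight from the tokenizer's tokenizer_config.json, which is what lets the pipeline change above take effect. A hedged usage sketch (the model id is illustrative; any text-generation model whose tokenizer config sets "add_bos_token": true behaves the same):

import { pipeline } from "@huggingface/transformers";

const generator = await pipeline("text-generation", "Xenova/llama2.c-stories15M");

// BOS is now prepended automatically because the tokenizer requests it...
const result = await generator("hello", { max_new_tokens: 3 });

// ...but callers can still opt out explicitly, as before.
const noSpecial = await generator("hello", { max_new_tokens: 3, add_special_tokens: false });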

tests/bundles.test.js

Lines changed: 4 additions & 4 deletions
@@ -9,17 +9,17 @@ const result = await generator("hello", { max_new_tokens: 3, return_full_text: f
 process.stdout.write(result[0].generated_text);
 `;
 
-const TARGET_OUTPUT = "erdingsAndroid Load";
+const TARGET_OUTPUT = "erdingsdelete mely";
 
 const wrap_async_iife = (code) => `(async function() { ${code} })();`;
 
 const check = (code, module = false) => {
   const args = ["-e", code];
   if (module) args.push("--input-type=module");
   const { status, stdout, stderr } = spawnSync("node", args);
-  expect(stderr.toString()).toBe(""); // No warnings or errors are printed
-  expect(stdout.toString()).toBe(TARGET_OUTPUT); // The output should match
-  expect(status).toBe(0); // The process should exit cleanly
+  expect(stderr.toString()).toEqual(""); // No warnings or errors are printed
+  expect(stdout.toString()).toEqual(TARGET_OUTPUT); // The output should match
+  expect(status).toEqual(0); // The process should exit cleanly
 };
 
 describe("Testing the bundle", () => {

tests/pipelines/test_pipelines_text_generation.js

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ export default () => {
 
     describe("batch_size=1", () => {
       const text_input = "hello";
-      const generated_text_target = "erdingsAndroid Load";
+      const generated_text_target = "erdingsdelete mely";
      const text_target = [{ generated_text: text_input + generated_text_target }];
       const new_text_target = [{ generated_text: generated_text_target }];

tests/tokenizers.test.js

Lines changed: 56 additions & 23 deletions
@@ -54,7 +54,6 @@ describe("Tokenizer padding/truncation", () => {
   }, MAX_TOKENIZER_LOAD_TIME);
 
   describe("return_tensor=false (jagged array)", () => {
-
     test("jagged array output when return_tensor is false", () => {
       const output = tokenizer(inputs, {
         return_tensor: false,
@@ -105,7 +104,6 @@ describe("Tokenizer padding/truncation", () => {
       compare(output, expected);
     });
 
-
     test("No padding, max_length=3 (implicit truncation strategy)", () => {
       const output = tokenizer(inputs_2, {
         padding: false,
@@ -129,9 +127,18 @@ describe("Tokenizer padding/truncation", () => {
         return_tensor: false,
       });
       const expected = {
-        input_ids: [[1037, 0, 0, 0, 0], [1038, 1039, 1040, 1041, 1042]],
-        token_type_ids: [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
-        attention_mask: [[1, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
+        input_ids: [
+          [1037, 0, 0, 0, 0],
+          [1038, 1039, 1040, 1041, 1042],
+        ],
+        token_type_ids: [
+          [0, 0, 0, 0, 0],
+          [0, 0, 0, 0, 0],
+        ],
+        attention_mask: [
+          [1, 0, 0, 0, 0],
+          [1, 1, 1, 1, 1],
+        ],
       };
       compare(output, expected);
     });
@@ -161,48 +168,75 @@ describe("Tokenizer padding/truncation", () => {
         return_tensor: false,
       });
       const expected = {
-        input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
-        token_type_ids: [[0, 0, 0], [0, 0, 0]],
-        attention_mask: [[1, 0, 0], [1, 1, 1]],
+        input_ids: [
+          [1037, 0, 0],
+          [1038, 1039, 1040],
+        ],
+        token_type_ids: [
+          [0, 0, 0],
+          [0, 0, 0],
+        ],
+        attention_mask: [
+          [1, 0, 0],
+          [1, 1, 1],
+        ],
       };
       compare(output, expected);
     });
 
     test("Padding 'max_length' without truncation, max_length=3", () => {
       const output = tokenizer(inputs_2, {
-        padding: 'max_length',
+        padding: "max_length",
         truncation: false,
         max_length: 3,
         add_special_tokens: false,
         return_tensor: false,
       });
       const expected = {
-        input_ids: [[1037, 0, 0], [1038, 1039, 1040, 1041, 1042]],
-        token_type_ids: [[0, 0, 0], [0, 0, 0, 0, 0]],
-        attention_mask: [[1, 0, 0], [1, 1, 1, 1, 1]],
+        input_ids: [
+          [1037, 0, 0],
+          [1038, 1039, 1040, 1041, 1042],
+        ],
+        token_type_ids: [
+          [0, 0, 0],
+          [0, 0, 0, 0, 0],
+        ],
+        attention_mask: [
+          [1, 0, 0],
+          [1, 1, 1, 1, 1],
+        ],
       };
       compare(output, expected);
     });
 
     test("Padding 'max_length' with truncation, max_length=3", () => {
       const output = tokenizer(inputs_2, {
-        padding: 'max_length',
+        padding: "max_length",
         truncation: true,
         max_length: 3,
         add_special_tokens: false,
         return_tensor: false,
       });
       const expected = {
-        input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
-        token_type_ids: [[0, 0, 0], [0, 0, 0]],
-        attention_mask: [[1, 0, 0], [1, 1, 1]],
+        input_ids: [
+          [1037, 0, 0],
+          [1038, 1039, 1040],
+        ],
+        token_type_ids: [
+          [0, 0, 0],
+          [0, 0, 0],
+        ],
+        attention_mask: [
+          [1, 0, 0],
+          [1, 1, 1],
+        ],
       };
       compare(output, expected);
     });
 
     test("Padding 'max_length' without truncation and max_length=null", () => {
       const output = tokenizer(inputs_2, {
-        padding: 'max_length',
+        padding: "max_length",
         truncation: false,
         max_length: null,
         add_special_tokens: false,
@@ -211,23 +245,22 @@ describe("Tokenizer padding/truncation", () => {
       const expected = {
         input_ids: [
           [1037, ...Array(511).fill(0)],
-          [1038, 1039, 1040, 1041, 1042, ...Array(507).fill(0)]
+          [1038, 1039, 1040, 1041, 1042, ...Array(507).fill(0)],
         ],
         token_type_ids: [
           [0, ...Array(511).fill(0)],
-          [0, 0, 0, 0, 0, ...Array(507).fill(0)]
+          [0, 0, 0, 0, 0, ...Array(507).fill(0)],
         ],
         attention_mask: [
           [1, ...Array(511).fill(0)],
-          [1, 1, 1, 1, 1, ...Array(507).fill(0)]
+          [1, 1, 1, 1, 1, ...Array(507).fill(0)],
         ],
       };
       compare(output, expected);
     });
   });
 
   describe("return_tensor=true", () => {
-
     test("throws error when tensor output is requested for a jagged array", () => {
       expect(() => tokenizer(inputs)).toThrow("Unable to create tensor");
     });
@@ -329,7 +362,7 @@ describe("Tokenizer padding/truncation", () => {
 
     test("padding:'max_length' pads to the specified max_length", () => {
       const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
-        padding: 'max_length',
+        padding: "max_length",
         truncation: true,
         add_special_tokens: false,
         max_length: 3,
@@ -347,7 +380,7 @@ describe("Tokenizer padding/truncation", () => {
       [0n, 0n, 0n],
     ]);
   });
-  })
+  });
 });
 
 describe("Token type ids", () => {

tests/utils/logits_process.test.js

Lines changed: 13 additions & 13 deletions
@@ -35,17 +35,17 @@ describe("Logits Processors", () => {
       async () => {
         const text_input = "hello";
 
-        const generated_text_target = " Bert explicit wed digasset";
+        const generated_text_target = "\uff0d Giuseppeitte natoud";
         const text_target = [{ generated_text: text_input + generated_text_target }];
 
         const output = await pipe(text_input, {
           max_new_tokens: 5,
           bad_words_ids: [
-            // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
+            // default: [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n]
             [18547],
 
-            // block #1: [22172n, 16662n, 6261n, 18916n, 29109n, 799n]
-            [6261, 18916],
+            // block #1: [1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n]
+            [18824, 16621],
           ],
         });
         compare(output, text_target);
@@ -58,22 +58,22 @@ describe("Logits Processors", () => {
       async () => {
         const text_input = "hello";
 
-        const generated_text_target = "erdingsdeletearus)?nor";
+        const generated_text_target = "erdingsdelete войsequ族";
         const text_target = [{ generated_text: text_input + generated_text_target }];
 
         // Construct long list of bad words
         const bad_words_ids = [];
-        // default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
+        // default: [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n]
         for (let i = 0; i < 100000; ++i) {
           bad_words_ids.push([i * 2]); // block all even numbers
         }
-        // block #1: [22172n, 18547n, 8143n, 30327n, 20061n, 18193n]
+        // block #1: [1n, 22172n, 18547n, 8143n, 30327n, 624n, 2806n, 2004n]
         bad_words_ids.push([8143, 30327]);
 
-        // block #2: [22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
+        // block #2: [1n, 22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
         bad_words_ids.push([18547, 8143, 29485]);
 
-        // block #3: [22172n, 18547n, 8143n, 26465n, 6877n, 15459n]
+        // block #3: [1n, 22172n, 18547n, 8143n, 7587n, 6831n, 30999n]
         const output = await pipe(text_input, { max_new_tokens: 5, bad_words_ids });
         compare(output, text_target);
       },
@@ -85,19 +85,19 @@ describe("Logits Processors", () => {
       async () => {
         const text_input = "this is a test";
 
-        const generated_text_target = "кт México constructed lake user";
+        const generated_text_target = "кт México constructed lake års";
         const text_target = [{ generated_text: text_input + generated_text_target }];
 
         const output = await pipe(text_input, {
           max_new_tokens: 5,
           bad_words_ids: [
-            // default: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
+            // default: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 31252n]
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3931], // should never trigger (longer than input sequence)
 
-            // block #1: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
+            // block #1: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 31252n]
             [3931, 14756, 7811],
 
-            // result: [445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 1404n]
+            // result: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 21948n]
           ],
         });
         compare(output, text_target);
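
The updated comments and targets all follow from the pipeline change: this tokenizer sets add_bos_token, so every encoded prompt now starts with BOS (1n), which shifts the sampled continuations. For reference, a sketch of the bad_words_ids semantics these tests rely on (`pipe` is the text-generation pipeline instance the test file constructs):

const output = await pipe("hello", {
  max_new_tokens: 5,
  bad_words_ids: [
    [18547],       // a single-token entry bans that token outright
    [8143, 30327], // a multi-token entry bans its last token only when the
                   // preceding tokens match the tail of the sequence so far
  ],
});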
