-
Notifications
You must be signed in to change notification settings - Fork 31
Fix gpt-tfjs bugs, add tests and refactor code #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9c72f72
133ce56
94d664a
b9e1edc
6265316
c3134be
20700f3
0b27fe7
7fa3c12
37fbbed
f9f932d
1f9271a
0ed7a1f
433959b
2e6e834
2faae1c
d808d8e
5eba6b2
cb40e6b
73300b6
d287fd3
1650349
1677af8
a7655e6
21f371c
6721717
1f0c526
a19fd2b
5c5ecde
a80c403
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import { TEXT_PREPROCESSING } from './index.js' | ||
| import { expect } from 'chai' | ||
|
|
||
| import type { Task } from '../../../index.js' | ||
| import * as tf from '@tensorflow/tfjs' | ||
|
|
||
| describe('text preprocessing', function () { | ||
| const [tokenize, leftPadding] = TEXT_PREPROCESSING | ||
| // Use a function to create different task object for each test (otherwise the tokenizer gets cached) | ||
| function initMockTask(): Task { | ||
| return { | ||
| id: 'mock-task-id', | ||
| displayInformation: {}, | ||
| trainingInformation: { | ||
| modelID: 'model-id', | ||
| epochs: 1, | ||
| roundDuration: 1, | ||
| validationSplit: 0, | ||
| batchSize: 8, | ||
| scheme: 'local', | ||
| dataType: 'text', | ||
| tokenizer: 'Xenova/gpt2', | ||
| }} | ||
| } | ||
|
|
||
| const text = "Hello world, a bc 1 2345, '? 976. Wikipedia is a free content online encyclopedia written and maintained by a community \n of volunteers, known as Wikipedians. Founded by Jimmy Wales and Larry Sanger on January 15, 2001, Wikipedia is hosted by the Wikimedia Foundation, an American nonprofit organization that employs a staff of over 700 people.[7]" | ||
| const expectedTokens = [15496, 995, 11, 257, 47125, 352, 2242, 2231, 11, 705, 30, 860, 4304, 13, 15312, 318, 257, 1479, 2695, 2691, 45352, 3194, 290, 9456, 416, 257, 2055, 220, 198, 286, 11661, 11, 1900, 355, 11145, 46647, 1547, 13, 4062, 276, 416, 12963, 11769, 290, 13633, 311, 2564, 319, 3269, 1315, 11, 5878, 11, 15312, 318, 12007, 416, 262, 44877, 5693, 11, 281, 1605, 15346, 4009, 326, 24803, 257, 3085, 286, 625, 13037, 661, 3693, 22, 60] | ||
|
|
||
| it('can tokenize text', async () => { | ||
| const { tokens } = await tokenize.apply(Promise.resolve(text), initMockTask()) as { tokens: number[]} | ||
| expect(tokens).to.be.deep.equal(expectedTokens) | ||
| }) | ||
|
|
||
| it('can truncate inputs when tokenizing', async () => { | ||
| const truncationTask = initMockTask() | ||
| truncationTask.trainingInformation.maxSequenceLength = 10 | ||
| const { tokens } = await tokenize.apply(Promise.resolve(text), truncationTask) as { tokens: number[] } | ||
| const expectedLength = truncationTask.trainingInformation.maxSequenceLength + 1 // + 1 because tokenization includes an extra token label for next label prediction | ||
| expect(tokens.length).to.be.equal(expectedLength) | ||
| expect(tokens).to.be.deep.equal(expectedTokens.slice(0, expectedLength)) | ||
| }) | ||
|
|
||
| it('can left pad tokens', async () => { | ||
| // Create a task where output token sequence should all have length 20 | ||
| const paddingTask = initMockTask() | ||
| paddingTask.trainingInformation.maxSequenceLength = 20 | ||
|
|
||
| // Create a token sequence of length 10 | ||
| const tokens = { tokens: [0,1,2,3,4,5,6,7,8,9] } | ||
| const { xs, ys } = await leftPadding.apply(Promise.resolve(tokens), paddingTask) as { xs: tf.Tensor1D, ys: tf.Tensor2D } | ||
| const xsArray = await xs.array() | ||
| const ysArray = await ys.array() | ||
|
|
||
| // Output sequences should have shape (20) and (20, 50258), 50258 being the size of the vocab for gpt2 | ||
| expect(xsArray.length).to.be.equal(paddingTask.trainingInformation.maxSequenceLength) | ||
| expect(ysArray.length).to.be.equal(paddingTask.trainingInformation.maxSequenceLength) | ||
| expect(ysArray[0].length).to.be.equal(50258) | ||
|
|
||
| // xs should be left pad with gpt2's padding token 50256 to be of length 20. | ||
| // We expect the last token of input token sequence (9) to not be included in xs since it doesn't have a next token to be predicted | ||
| const paddingToken = 50256 | ||
| const expectedXs = Array.from({length:11}).map(_ => paddingToken).concat(tokens.tokens.slice(0,9)) | ||
| expect(xsArray).to.be.deep.equal(expectedXs) | ||
|
|
||
| // ys should be a one hot encoding of the next token in xs | ||
| // if the input tokens are [0,1,2,3] then the labels are [1,2,3] which are then one-hot encoded | ||
| // So the sum of each row should be equal to 1 | ||
| const expectedOneHot = Array.from({ length: 20 }).map(_ => 1) | ||
| expect(await ys.sum(-1).array()).to.be.deep.equal(expectedOneHot) | ||
|
|
||
| // In each row, the index of the 1 should be the token id | ||
| const expectedYs = Array.from({length:10}).map(_ => paddingToken).concat(tokens.tokens) | ||
| expect(await ys.argMax(-1).array()).to.be.deep.equal(expectedYs) | ||
| }) | ||
|
|
||
| it('throws an error if no tokenizer is specified', async () => { | ||
| const invalidTask = initMockTask() | ||
| invalidTask.trainingInformation.tokenizer = undefined; | ||
| try { | ||
| await tokenize.apply(Promise.resolve("input text doesn't matter"), invalidTask) | ||
| } catch { | ||
| return | ||
| } | ||
| throw new Error("undefined tokenizer should have thrown an error") | ||
| }) | ||
| it('throws an error if the tokenizer name is invalid', async () => { | ||
| const invalidTask = initMockTask() | ||
| invalidTask['trainingInformation']['tokenizer'] = 'invalid-tokenizer-name' | ||
| try { | ||
| await tokenize.apply(Promise.resolve("input text doesn't matter"), invalidTask) | ||
| } catch { | ||
| return | ||
| } | ||
| throw new Error("invalid tokenizer name should have thrown an error") | ||
| }) | ||
|
|
||
| }) |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -7,16 +7,12 @@ type ModelType = | |||||
| | 'gpt-micro' | ||||||
| | 'gpt-nano' | ||||||
|
|
||||||
| export interface ModelSize { | ||||||
| nLayer?: number | ||||||
| nHead?: number | ||||||
| nEmbd?: number | ||||||
| } | ||||||
|
|
||||||
| export interface GPTConfig { | ||||||
| lr: number | ||||||
| blockSize: number | ||||||
| vocabSize: number | ||||||
| modelType: ModelType | ||||||
| name?: string, | ||||||
| evaluate?: boolean | ||||||
| maxEvalBatches?: number | ||||||
| evaluateEvery?: number | ||||||
|
|
@@ -30,13 +26,16 @@ export interface GPTConfig { | |||||
| embdDrop?: number | ||||||
| tokEmb?: boolean | ||||||
| lmHead?: boolean | ||||||
| modelType: ModelType | ||||||
| nLayer?: number | ||||||
| nHead?: number | ||||||
| nEmbd?: number | ||||||
| } | ||||||
|
|
||||||
| export const DEFAULT_CONFIG: Required<GPTConfig> = { | ||||||
| name: 'transformer', | ||||||
| lr: 0.001, | ||||||
| weightDecay: 0, | ||||||
| maxIter: 10_000, | ||||||
| maxIter: 5, | ||||||
| verbose: 0, | ||||||
| modelType: 'gpt-nano', | ||||||
| evaluate: true, | ||||||
|
|
@@ -50,7 +49,16 @@ export const DEFAULT_CONFIG: Required<GPTConfig> = { | |||||
| residDrop: 0.2, | ||||||
| embdDrop: 0.2, | ||||||
| tokEmb: true, | ||||||
| lmHead: true | ||||||
| lmHead: true, | ||||||
| nLayer: 3, | ||||||
| nHead: 3, | ||||||
| nEmbd: 48, | ||||||
| } | ||||||
|
|
||||||
| export type ModelSize = { | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's equivalent but nicer to use
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was actually searching for pros and cons and didn't find anything significant — why do you prefer interface?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, it's mostly a habit — I picked it up from one of the ESLint rules. Now that I reread the documentation, there aren't many differences; the only plus for interface that I found myself agreeing with is that it shows better error messages |
||||||
| nLayer: number | ||||||
| nHead: number | ||||||
| nEmbd: number | ||||||
| } | ||||||
|
|
||||||
| export function getModelSizes (modelType: ModelType): Required<ModelSize> { | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,44 @@ | ||||
| import { expect } from 'chai' | ||||
| import * as tf from '@tensorflow/tfjs-node' | ||||
| import { AutoTokenizer } from '@xenova/transformers'; | ||||
| import { GPT } from './index.js' | ||||
| import { type GPTConfig } from './config.js' | ||||
|
|
||||
| describe('gpt-tfjs', function() { | ||||
| this.timeout(50_000) | ||||
| const data = "Lorem ipsum dolor sit" | ||||
|
|
||||
| const config: GPTConfig = { | ||||
| modelType: 'gpt-nano', | ||||
| lr: 0.01, | ||||
| maxIter: 10, | ||||
| evaluateEvery:10, | ||||
| maxEvalBatches: 10, | ||||
| blockSize: 8, | ||||
| vocabSize: 50258 | ||||
| } | ||||
|
|
||||
| it('can overfit one sentence', async () => { | ||||
| const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') | ||||
| const datasetSource = new tf.data.FileDataSource(Buffer.from(data)) | ||||
| const textDataset = new tf.data.TextLineDataset(datasetSource) | ||||
| const tokenDataset = textDataset.map((text: string) => { | ||||
| const { input_ids: tokens } = tokenizer(text, { | ||||
| padding: true, | ||||
| truncation: true, | ||||
| return_tensor: false, | ||||
| max_length: config.blockSize + 1, | ||||
| }) as { input_ids: number[] } | ||||
| const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length + 1) | ||||
| const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32') | ||||
| return {xs, ys} | ||||
| }).repeat().batch(64) | ||||
|
|
||||
| const model = new GPT(config) | ||||
| const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs | ||||
| for await (const _ of logGenerator); // Await the end of training | ||||
| const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1) | ||||
| console.log(generation) | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit' | ||||
| }) | ||||
| }) | ||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does nothing, as `x` is a Promise and does not throw the created Error.