Fix and rework GPT-TF.js #807
Merged

Changes from all commits
18 commits:

- `33b9fcf` docs/CONTRIBUTING: add documentation on debug statements and cypress (JulienVig)
- `b31485f` discojs/src/default_tasks: change default batch size and block size (JulienVig)
- `52663c5` cli: add train_gpt script (JulienVig)
- `a72a3c6` discojs/src/models/gpt/models: fix loss averaging across iterations (JulienVig)
- `0a219e0` discojs & cli: change gpt2 vocab size to 50257 instead of 50258 (JulienVig)
- `04dd6c3` discojs/src/models/gpt: allow model init config to be partial (JulienVig)
- `0370a20` discojs/src/models/gpt/index: link source repo (JulienVig)
- `340e171` discojs/src/models/gpt: always use the token embeddings, a language m… (JulienVig)
- `40d51d7` discojs/src/models/gpt: generation code: clean, document and improve.… (JulienVig)
- `f8ede3a` discojs/src/models/gpt/layers: document tensor operations, rename lay… (JulienVig)
- `84cee62` discojs/src/models/gpt/layers: share weights between token embeddings… (JulienVig)
- `d3c0689` discojs/src/models/gpt/layers: fix weight initializations (JulienVig)
- `884e14e` discojs/src/models/gpt: add seed, loss is now identical between runs (JulienVig)
- `ed73afb` discojs/src/task/training_information: make maxSequenceLength a requi… (JulienVig)
- `9af6b1f` discojs/src/default_tasks/wikitext: use task's maxSequenceLength in t… (JulienVig)
- `8806b7a` *: replace line by line text loaders by chunk by chunk text loaders (JulienVig)
- `40eb86c` discojs*: rename .unbatch() to .flatten() (JulienVig)
- `30de4fb` discojs*,cli*: rename blockSize and maxSequenceLength to contextLength (JulienVig)
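One of the commits above (`a72a3c6`) fixes loss averaging across iterations. In general, reporting a running mean of per-batch losses, rather than only the latest batch's loss, can be sketched as follows (a hypothetical illustration, not the actual discojs code):

```typescript
// Running mean of per-batch losses: after batch i, report the mean of
// all losses seen so far instead of only losses[i].
// Hypothetical sketch; not the discojs implementation.
function runningMean(losses: number[]): number[] {
  const means: number[] = [];
  let sum = 0;
  losses.forEach((loss, i) => {
    sum += loss;
    means.push(sum / (i + 1));
  });
  return means;
}

console.log(runningMean([4, 2, 3])); // [4, 3, 3]
```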
New workflow file, `record-cypress` (+76 lines):

```yaml
name: record-cypress
on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  download-datasets:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          submodules: true
      - uses: actions/cache@v4
        with:
          path: datasets
          key: datasets-${{ hashFiles('datasets/**') }}
      - run: datasets/populate

  build-lib:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm --workspace=discojs run build

  build-lib-web:
    needs: build-lib
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm run --workspace=discojs build
      - run: npm run --workspace=discojs-web build

  record-test-webapp:
    needs: [build-lib, build-lib-web, download-datasets]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          submodules: true
      - uses: actions/cache@v4
        with:
          path: datasets
          key: datasets-${{ hashFiles('datasets/**') }}
      - uses: actions/setup-node@v4
        with:
          node-version-file: .nvmrc
          cache: npm
      - run: npm ci
      - run: npm --workspace={discojs,discojs-web} run build
      - run: npm --workspace=webapp run test:unit
      - uses: cypress-io/github-action@v6
        with:
          working-directory: webapp
          install: false
          start: npm start
          wait-on: 'http://localhost:8081' # Waits for above
          # Records to Cypress Cloud
          # https://docs.cypress.io/guides/cloud/projects#Set-up-a-project-to-record
          record: true
        env:
          VITE_SERVER_URL: http://server
          CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }}
```
New file (+49 lines): a CLI script that trains a small GPT and samples from it, runnable with `npm run run_gpt`:

```typescript
import "@tensorflow/tfjs-node"
import { AutoTokenizer } from "@xenova/transformers";
import { models, processing, Dataset } from "@epfml/discojs";
import { List } from "immutable";

async function main(): Promise<void> {
  const data = "Lorem ipsum dolor sit amet, consectetur adipis"
  const seed = 42

  const config: models.GPTConfig = {
    modelType: 'gpt-nano',
    lr: 0.01,
    maxIter: 50,
    evaluateEvery: 50,
    maxEvalBatches: 10,
    contextLength: 16,
    seed
  }

  const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')

  const tokenDataset = new Dataset([data])
    .map((text: string) => processing.tokenize(tokenizer, text))
    .flatten()
    .batch(config.contextLength + 1, 1)
    .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
    .repeat()
    .batch(8);

  const model = new models.GPT(config)
  for await (const logs of model.train(tokenDataset, undefined)) {
    console.log(logs)
  }

  let tokens = processing.tokenize(tokenizer, "Lorem");

  const maxNewTokens = 14
  for (let n = 0; n < maxNewTokens; n++) {
    const next: number = (await model.predict(
      List.of(tokens), { seed })
    ).first();
    tokens = tokens.push(next)
  }
  const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true })
  console.log(generation)
}

// You can run this example with "npm run run_gpt" from this folder
main().catch(console.error)
```
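The dataset pipeline above windows the flattened token stream with `.batch(config.contextLength + 1, 1)` and then splits each window into an input sequence and its next-token label. A stand-alone sketch of that sliding-window logic using plain arrays (the `slidingWindows` helper is illustrative, not part of the discojs API):

```typescript
// Sliding windows of `size` tokens with the given stride; each window is
// split into an input (all tokens but the last) and a label (the last token).
function slidingWindows(tokens: number[], size: number, stride: number): number[][] {
  const windows: number[][] = [];
  for (let i = 0; i + size <= tokens.length; i += stride) {
    windows.push(tokens.slice(i, i + size));
  }
  return windows;
}

const contextLength = 4;
const tokens = [10, 11, 12, 13, 14, 15];

const pairs = slidingWindows(tokens, contextLength + 1, 1)
  .map((w) => [w.slice(0, -1), w[w.length - 1]] as [number[], number]);

console.log(pairs);
// first pair: input [10, 11, 12, 13], label 14
```

With stride 1 every token (past the first `contextLength`) serves as a label exactly once, which is why the script batches `contextLength + 1` tokens at a time.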
Updated text loader in discojs-node, replacing line-by-line reading with chunk-by-chunk streaming:

```diff
-import * as fs from "node:fs/promises";
-import * as readline from "node:readline/promises";
-
+import createDebug from "debug";
+import { createReadStream } from 'node:fs';
 import { Dataset, Text } from "@epfml/discojs";

+const debug = createDebug("discojs-node:loaders:text");
+
+/**
+ * Returns chunks of text. Use `minChunkSize` to ensure that
+ * each chunk is bigger than the expected sequence length.
+ *
+ * @param path path to the text file to read
+ * @returns a dataset of tokenized input and label sequences
+ */
 export function load(path: string): Dataset<Text> {
   return new Dataset(async function* () {
-    const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
-
-    // `readline` is a bit overkill but seems standard
-    // https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
-    yield* readline.createInterface({ input, crlfDelay: Infinity });
+    // Create a stream to read the text file chunk by chunk
+    const stream = createReadStream(path, { encoding: "utf8" });
+    for await (const chunk of stream) {
+      if (typeof chunk !== 'string')
+        throw new Error('Expected file stream to yield string')
+
+      debug("yield chunk of length: %o", chunk.length);
+      yield chunk
+    }
   });
 }
```
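The chunk-by-chunk pattern used by the loader can be sketched outside discojs as a plain async generator over a Node file stream. A minimal, self-contained sketch (the `readChunks` and `demo` names are hypothetical, not discojs identifiers):

```typescript
import { createReadStream, writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";

// Async generator yielding raw string chunks from a file stream,
// mirroring the loader above: no line splitting, just stream chunks.
async function* readChunks(path: string): AsyncGenerator<string> {
  const stream = createReadStream(path, { encoding: "utf8" });
  for await (const chunk of stream) {
    if (typeof chunk !== "string") throw new Error("expected string chunk");
    yield chunk;
  }
}

// Usage: the yielded chunks concatenate back to the original file content.
async function demo(): Promise<string> {
  const path = join(tmpdir(), "loader-sketch.txt");
  writeFileSync(path, "hello chunked world");
  let text = "";
  for await (const chunk of readChunks(path)) text += chunk;
  return text;
}

demo().then(console.log).catch(console.error);
```

Reading chunk by chunk avoids materializing one dataset item per line, which matters for datasets like wikitext where sequence windows cross line boundaries.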