-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path index.ts
63 lines (54 loc) · 1.49 KB
/
index.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import * as tf from '@tensorflow/tfjs-node'
import { CONFIG, GPT, CharDataset, Trainer } from '@gpt/model'
/**
 * Trains a character-level GPT ("gpt-pico" config) on the Tiny Shakespeare
 * corpus, then samples 500 new tokens from the trained model and prints the
 * decoded text. Finishes by disposing tensors and dumping tfjs memory stats.
 *
 * NOTE(review): assumes `model.generate` resolves to a rank-2 int tensor of
 * token ids (hence the `number[][]` cast) — confirm against @gpt/model.
 */
async function start() {
  // Report which tfjs backend is active for this run.
  console.log(`Current backend: ${tf.getBackend()}`)

  // Character-level dataset, streamed from the project repository.
  const dataset = await CharDataset({
    textSourceURL:
      'https://raw.githubusercontent.com/trekhleb/homemade-gpt-js/refs/heads/main/playground-web/public/dataset-tinyshakespeare.txt',
  })

  // Training hyperparameters (key names match the Trainer params contract).
  const hyperParams = {
    batchSize: 8,
    blockSize: 8,
    maxIters: 1800,
    evalInterval: 300,
    evalIterations: 50,
    learningRate: 1e-3,
  }

  // Build the model: smallest preset, with block size and vocabulary
  // taken from the hyperparameters / dataset.
  const model = GPT({
    ...CONFIG['gpt-pico'],
    blockSize: hyperParams.blockSize,
    vocabSize: dataset.vocabSize,
  })
  console.log('\nModel summary:', model.summary())

  console.log('\nStart training:')
  const trainer = Trainer({
    model,
    dataset,
    params: hyperParams,
    callbacks: {
      // Log evaluation metrics each time the trainer evaluates.
      onEval: (evalInfo) => {
        console.log(evalInfo)
      },
    },
  })
  await trainer.train()

  console.log('\nStart generation:')
  // Seed generation with a single token id of 1 (shape [1, 1]).
  const seed = tf.ones([1, 1], 'int32')
  const generated = await model.generate({
    idx: seed,
    maxNewTokens: 500,
    doSample: true,
    // Explicitly no top-k filtering: sample from the full distribution.
    topK: undefined,
  })
  const tokenIds = (await generated.array()) as number[][]
  console.log(dataset.decode(tokenIds[0]))

  console.log('\nDisposing the model and dataset')
  dataset.dispose()
  generated.dispose()
  model?.dispose?.()

  // Show remaining tensor/byte counts to spot leaks.
  console.log('\nDebug memory consumption:')
  console.table(tf.memory())
}
// Run the script. The original left the returned promise floating, so any
// failure (dataset fetch, training, generation) surfaced only as an
// unhandled-rejection warning; catch it, log it, and mark the process failed.
start().catch((err: unknown) => {
  console.error(err)
  process.exitCode = 1
})