flat_tokens_c4_a100x1_84m.yaml
# python -m train --config-name=flat_tokens_c4_a100x1_84m +paths.model_name=flat_token_84m
training:
  seed: 0
  tokens:
    batch: 64
    len: 1024
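  # Rough token budget implied by the values above together with steps: 26000 below:
  # 64 sequences/step * 1024 tokens/sequence = 65,536 tokens/step, i.e. ~1.7B tokens over the full run.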
  # AdamW optimizer parameters
  # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
  adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
  adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
  adam_eps: 1.e-8 # A small constant applied to the denominator outside of the square root.
  adam_eps_root: 0. # A small constant applied to the denominator inside the square root.
  weight_decay: 0.1 # AdamW weight decay
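  # Note: these names mirror optax.adamw's arguments (b1, b2, eps, eps_root, weight_decay);
  # whether the trainer literally calls optax.adamw is an assumption here, not something this file specifies.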
  # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
  # The learning rate schedule has two parts:
  # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [warmup_steps]
  # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from [warmup_steps] to [steps_for_lr]
  warmup_steps: 2600
  steps: 26000
  steps_for_lr: 26000
  learning_rate: 3.0e-4
  cosine_learning_rate_final_fraction: 0.1
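  # Sketch of the schedule implied by the two-part description above (assuming it is implemented literally):
  #   final_lr = learning_rate * cosine_learning_rate_final_fraction                         # 3.0e-5
  #   lr(step) = learning_rate * step / warmup_steps                                         # step < warmup_steps
  #   lr(step) = final_lr + 0.5 * (learning_rate - final_lr)
  #              * (1 + cos(pi * (step - warmup_steps) / (steps_for_lr - warmup_steps)))     # otherwise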
model:
  d_model: 512
  n_q_per_kv: 1
  n_kv: 8
  d_head: 128
  layers: 8
  d_ff: 4096
  vocab: 32768
  rope_max_timescale: 10000
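  # Back-of-the-envelope parameter count (assuming a gated FFN and a single shared embedding matrix,
  # which is an assumption about the model code, not something this file specifies):
  #   attention: Q/K/V/O projections ~ 4 * 512 * (8 * 128) ~ 2.1M per layer
  #   gated FFN: 3 * 512 * 4096                            ~ 6.3M per layer
  #   8 layers ~ 67M; embedding: 32768 * 512 ~ 16.8M; total ~ 84M, matching the "84m" in the config name.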
paths:
  # Can also be a path to GCS, e.g. 'gcs://your_bucket/your_output_path'
  root_working_dir: '~/seqax_outputs'
num_hosts: 1
io:
  max_io_threads: 1024
# Define either hf_dataset or flat_tokens. Do not use both.
# flat_tokens requires more setup, but is better tested and doesn't waste tokens.
# Using flat_tokens requires first building a flat_tokens dataset with the script in tools/huggingface_to_flat_tokens.py
flat_tokens:
  filespec: 'gcs://path/to/your/dataset' # can be a path to a GCS directory or a local copy of the dataset
  streams: 1
  read_blocks_per_shuffle_buffer: 128
  sequences_per_read_block: 1024
  seed: 0
  sequence_packing: true
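  # Shuffle buffer implied by the two values above: 128 read blocks * 1024 sequences/block = 131,072 sequences
  # held for shuffling (assuming the loader shuffles at read-block granularity).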
mesh:
  d: 1
  t: 1
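  # d * t should equal the number of devices; 1 * 1 = 1 matches the single A100 this config targets
  # (assuming 'd' is the data-parallel axis and 't' the tensor-parallel axis).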
checkpoint_interval: 100