Commit 9996f34

Add llama 3 tokenizer (apple#850)
* Add llama 3 tokenizer: add a new version called V3_TIKTOKEN; other edits based on suggestions.
* Handle special tokens like other vocabularies.
* Use encode instead of encode_batch.
Parent commit: ad14de3
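
The commit adds a tiktoken-based Llama 3 vocabulary behind the new V3_TIKTOKEN version. The sketch below is illustrative only and is not the axlearn implementation: it assumes a local Llama-3-style BPE rank file named tokenizer.model, uses two of Llama 3's special tokens, and treats the pre-tokenization regex as an approximation. It shows the per-example encode call (with special tokens allowed) that the commit message contrasts with encode_batch.

# Illustrative sketch only; not the axlearn V3_TIKTOKEN code.
import tiktoken
from tiktoken.load import load_tiktoken_bpe

# Assumed local path to a Llama-3-style tiktoken BPE rank file.
mergeable_ranks = load_tiktoken_bpe("tokenizer.model")
special_tokens = {
    # Two of Llama 3's reserved special tokens, placed after the BPE ranks.
    "<|begin_of_text|>": len(mergeable_ranks),
    "<|end_of_text|>": len(mergeable_ranks) + 1,
}
enc = tiktoken.Encoding(
    name="llama3_sketch",
    # Assumed cl100k-style pre-tokenization pattern; treat as an approximation.
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens=special_tokens,
)

text = "<|begin_of_text|>hello world<|end_of_text|>"
# encode() (not encode_batch()) on a single string; allowed_special="all" maps
# special tokens to their reserved ids instead of raising an error.
ids = enc.encode(text, allowed_special="all")
print(ids)
print(enc.decode(ids))

Per the commit message, special tokens are handled the same way as in the other vocabularies, and encoding goes through encode rather than encode_batch.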

File tree: 85 files changed (+4629, -334 lines)

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
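
For the existing v3 configs, these diffs bump model.decoder.vocab_size from 128256 to 131072, and the token embedding shape in the _init.txt test data tracks it directly (shape=[vocab_size, 2048]). A quick arithmetic check of the embedding table size, written as plain Python rather than anything from the repo:

emb_dim = 2048
for vocab_size in (128256, 131072):
    # Token embedding table is [vocab_size, emb_dim].
    print(vocab_size, vocab_size * emb_dim)
# 128256 -> 262,668,288 embedding parameters
# 131072 -> 268,435,456 embedding parameters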

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt

Lines changed: 284 additions & 0 deletions
Large diffs are not rendered by default.
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)

@@ -0,0 +1,10 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
