ord('a') -> 97, chr(97) -> 'a'
a. chr(0) returns the null character (Unicode code point U+0000), i.e. '\x00'.
b. Printing it shows nothing visible, while its __repr__() is '\x00', as seen above.
c.
>>> "this is a test" + chr(0) + "string"
'this is a test\x00string'
>>> print("this is a test" + chr(0) + "string")
this is a teststring
a. Use UTF-8 rather than UTF-16 or UTF-32: for typical (mostly ASCII) text, UTF-8 produces a much shorter list of ints (bytes).
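For example, an illustrative REPL check:
>>> s = "this is a test"
>>> len(s.encode('utf-8')), len(s.encode('utf-16')), len(s.encode('utf-32'))
(14, 30, 60)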
b. The code calls bytes([b]).decode on each byte individually, which assumes every single byte is a complete UTF-8 sequence; multi-byte characters such as those in '你好' cannot be decoded from a single byte.
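A minimal fix (a sketch; the function name here is ours) is to decode the whole byte string in one call so multi-byte characters are never split:
def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # Decode the entire UTF-8 sequence at once instead of byte by byte,
    # so multi-byte characters stay intact.
    return bytestring.decode("utf-8")
The original byte-by-byte version fails as soon as it hits a multi-byte character: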
>>> '你'.encode('utf-8')
b'\xe4\xbd\xa0'
>>> wrong('你好'.encode('utf-8'))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 2, in wrong
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
c. For example, bytes([228, 189]) is a two-byte sequence that does not decode to any Unicode character, since it is an incomplete UTF-8 multi-byte sequence; only the full three-byte sequence decodes:
>>> bytes([228]).decode('utf-8')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
>>> bytes([228, 189]).decode('utf-8')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 0-1: unexpected end of data
>>> bytes([228, 189, 160]).decode('utf-8')
'你'
(a) Training our BPE model on the TinyStories dataset with a vocabulary size of 10,000 takes about 1 minute 45 seconds. Inspecting the result, the longest tokens (listed below) all make sense.
Top 5 longest tokens (by bytes):
1) id=7168, len=15 bytes, value=b' accomplishment' (hex=206163636f6d706c6973686d656e74)
2) id=9152, len=15 bytes, value=b' disappointment' (hex=206469736170706f696e746d656e74)
3) id=9388, len=15 bytes, value=b' responsibility' (hex=20726573706f6e736962696c697479)
4) id=3236, len=14 bytes, value=b' uncomfortable' (hex=20756e636f6d666f727461626c65)
5) id=3524, len=14 bytes, value=b' compassionate' (hex=20636f6d70617373696f6e617465)
(b)
uv run py-spy record -o bpe_profile.svg -- python cs336_basics/bpe.py
In the resulting flame graph, the most time-consuming part is the _apply_merge function, which updates the token counts via the increase method. Before this, the most time-consuming part was transferring file contents to the worker processes, which we optimized by passing only the start and end byte offsets of each chunk instead of the whole text chunk.
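A rough sketch of that idea (placeholder names and pre-token pattern, not the actual implementation): each worker receives only (path, start, end) and reads its own slice of the file, so the parent process never pickles large strings.
from collections import Counter
from multiprocessing import Pool
import re

PRETOKEN_PAT = re.compile(r"\S+|\s+")  # placeholder pattern, not the real pre-tokenizer regex

def process_range(path: str, start: int, end: int) -> Counter:
    # Each worker re-opens the file and reads only its own byte range.
    with open(path, "rb") as f:
        f.seek(start)
        text = f.read(end - start).decode("utf-8", errors="ignore")
    return Counter(m.group(0) for m in PRETOKEN_PAT.finditer(text))

def pretokenize_parallel(path: str, boundaries: list[int], workers: int = 8) -> Counter:
    # boundaries: byte offsets delimiting the chunks (assumed precomputed)
    tasks = [(path, s, e) for s, e in zip(boundaries[:-1], boundaries[1:])]
    total = Counter()
    with Pool(workers) as pool:
        for counts in pool.starmap(process_range, tasks):
            total.update(counts)
    return total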
After optimizing _apply_merge to operate on sub_tokens instead of whole tokens, optimizing _select_most_frequent_pair with heapq, and improving the multiprocessing file transfer, the final flame graph shows that all remaining hot spots are relatively small and balanced.
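One way _select_most_frequent_pair could use heapq (a simplified sketch; it lazily discards stale entries and ignores tie-breaking among equal counts):
import heapq

def select_most_frequent_pair(pair_counts: dict, heap: list):
    # heap holds (-count, pair) entries pushed whenever a pair's count changes;
    # entries whose stored count no longer matches pair_counts are stale.
    while heap:
        neg_count, pair = heap[0]
        if pair_counts.get(pair, 0) == -neg_count:
            return pair
        heapq.heappop(heap)  # discard stale entry
    return None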
The flame graph's timing covers all subprocesses, not just system time. In fact, the most time-consuming part is the _process_range_for_pretokenization function, which reads the file and performs the pre-tokenization. Using scalene, we can check the system time.
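For example (assuming scalene is installed in the project environment):
uv run scalene cs336_basics/bpe.py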
(a) The OpenWebText dataset is extremely large; my local machine with 16 GB of RAM cannot handle it. Optimizations such as streaming the pre-tokens are possible, but the implementation complexity is significant, so I ran the training on a cloud server with more RAM.
Top 20 longest tokens (by bytes):
1) id=31286, len=27 bytes, value=b'---------------------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
2) id=30220, len=25 bytes, value=b'-------------------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
3) id=28759, len=23 bytes, value=b'-----------------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
4) id=27276, len=21 bytes, value=b'---------------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
5) id=23354, len=19 bytes, value=b' disproportionately' (hex=2064697370726f706f7274696f6e6174656c79)
6) id=24299, len=19 bytes, value=b' telecommunications' (hex=2074656c65636f6d6d756e69636174696f6e73)
7) id=26017, len=19 bytes, value=b'-------------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
8) id=28317, len=18 bytes, value=b' environmentalists' (hex=20656e7669726f6e6d656e74616c69737473)
9) id=31683, len=18 bytes, value=b' -----------------' (hex=202d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
10) id=14298, len=17 bytes, value=b' responsibilities' (hex=20726573706f6e736962696c6974696573)
11) id=16300, len=17 bytes, value=b' unconstitutional' (hex=20756e636f6e737469747574696f6e616c)
12) id=24598, len=17 bytes, value=b'-----------------' (hex=2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d)
13) id=25729, len=17 bytes, value=b' cryptocurrencies' (hex=2063727970746f63757272656e63696573)
14) id=26104, len=17 bytes, value=b' disproportionate' (hex=2064697370726f706f7274696f6e617465)
15) id=27085, len=17 bytes, value=b' misunderstanding' (hex=206d6973756e6465727374616e64696e67)
16) id=28544, len=17 bytes, value=b' counterterrorism' (hex=20636f756e746572746572726f7269736d)
17) id=29869, len=17 bytes, value=b'_________________' (hex=5f5f5f5f5f5f5f5f5f5f5f5f5f5f5f5f5f)
18) id=30256, len=17 bytes, value=b' characterization' (hex=20636861726163746572697a6174696f6e)
19) id=9260, len=16 bytes, value=b' representatives' (hex=20726570726573656e74617469766573)
20) id=10287, len=16 bytes, value=b' recommendations' (hex=207265636f6d6d656e646174696f6e73)
Some of the longest tokens appear unusual, but the training data actually contains strings like '---------------------------', making these results reasonable.
(b) The tokenizers trained on TinyStories and OWT differ significantly: the TinyStories vocabulary is dominated by common story words (e.g. ' accomplishment', ' disappointment'), while the OWT vocabulary includes longer and more diverse web-text tokens such as runs of '-' and '_'. The vocabularies and merge sequences simply reflect the byte-pair statistics of their respective training corpora.
(a) (b)
> python -m cs336_basics.tokenizer_experiments
TS sample with TS tokenizer: [118, 868, 500, 507, 266, 324, 616, 372, 263, 917, 473]
tokenizer's compression ratio: 3.73
TS sample with OWT tokenizer: [118, 803, 699, 414, 284, 309, 11045, 288, 262, 7763, 3576]
tokenizer's compression ratio: 3.73
---
OWT sample with OWT tokenizer: [77, 4103, 2155, 87, 4205, 5365, 45, 12000, 47, 752, 331, 1136, 548, 3321, 19169, 8095, 382, 284, 309, 2595, 352, 627, 6708, 45]
tokenizer's compression ratio: 3.38
OWT sample with TS tokenizer: [118, 803, 699, 414, 284, 309, 11045, 288, 262, 7763, 3576]
tokenizer's compression ratio: 2.45
When using the TinyStories tokenizer on an OpenWebText sample, the compression ratio drops from 3.38 to 2.45. This is expected: the TinyStories tokenizer only learned merges for the narrow vocabulary of children's stories, so much of OWT's more diverse text is split into shorter tokens, giving fewer bytes per token.
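For reference, the compression ratio here is bytes of UTF-8 text per token id (a sketch; the experiment script may compute it slightly differently):
def compression_ratio(text: str, tokenizer) -> float:
    # UTF-8 bytes of the sample divided by the number of token ids produced
    ids = tokenizer.encode(text)
    return len(text.encode("utf-8")) / len(ids)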
(c)
File Size Analysis:
File size on disk: 22,502,601 bytes (21.46 MB)
UTF-8 encoded size: 22,502,601 bytes (21.46 MB)
Character count: 22,493,387 characters
UTF-8 / File size ratio: 1.0000
Average bytes per character: 1.0004
Estimation for 825 GB text file:
If file size = 825 GB on disk
Estimated UTF-8 bytes = 825 GB × 1.0000
= 825.00 GB
= 885,837,004,800 bytes
============================================================
Tokenization Performance:
Validation data: 22,502,601 bytes
TinyStories Tokenizer:
Time: 34.28s | Speed: 656,409.30 bytes/s
Est. for 825GB: 1,349,519.28s (374.87h)
OWT Tokenizer:
Time: 34.51s | Speed: 652,048.39 bytes/s
Est. for 825GB: 1,358,544.89s (377.37h)
============================================================
After applying lru_cache and optimizing _bpe_encode_uncached using a doubly linked list and a min-heap, the performance improved significantly:
============================================================
Tokenization Performance:
Validation data: 22,502,601 bytes
TinyStories Tokenizer:
Time: 30.45s | Speed: 739,024.99 bytes/s
Est. for 825GB: 1,198,656.36s (332.96h)
OWT Tokenizer:
Time: 29.81s | Speed: 754,822.37 bytes/s
Est. for 825GB: 1,173,570.15s (325.99h)
============================================================
(d) In this implementation, we encode the dataset and store the token IDs as uint16. On my cloud server, the encoding speed is approximately 3 MB/s. uint16 covers token IDs 0-65,535, which comfortably fits our vocabulary sizes, whereas uint8 (0-255) is far too small and uint32 would double the storage.
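A minimal sketch of this storage step (the function and paths are illustrative, assuming the tokenizer exposes an encode method):
import numpy as np

def encode_to_uint16(tokenizer, in_path: str, out_path: str) -> None:
    # Token IDs fit in 0..65535, so uint16 halves the storage of uint32.
    ids = []
    with open(in_path, encoding="utf-8") as f:
        for line in f:
            ids.extend(tokenizer.encode(line))
    np.save(out_path, np.array(ids, dtype=np.uint16))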
Implemented modules: Linear, Embedding, RMSNorm, SwiGLU feed-forward network, Softmax, scaled_dot_product_attention, multihead_self_attention, TransformerBlock, transformer_lm
(a) Parameter count:
- Embedding weight: num_embeddings x embedding_dim
- Linear: output x input
- RMSNorm: d_model
- SwiGLU: w1, w3: d_ff x d_model; w2: d_model x d_ff
  - total: 2 x d_ff x d_model + d_model x d_ff = 3 x d_ff x d_model
- TransformerBlock
  - MultiheadSelfAttention (attn)
    - Linear q, k, v, out: 4 x d_model x d_model
  - SwiGLU (ffn)
    - 3 x d_ff x d_model
  - RMSNorm (ln1, ln2)
    - 2 x d_model
  - total: 4 x d_model x d_model + 3 x d_ff x d_model + 2 x d_model
Total for an n-layer Transformer LM: vocab_size x embedding_dim (token embedding) + n x (4 x d_model x d_model + 3 x d_ff x d_model + 2 x d_model) + d_model (final RMSNorm) + vocab_size x d_model (LM head).
Since we usually set embedding_dim = d_model, this simplifies to: 2 x vocab_size x d_model + n x (4 x d_model x d_model + 3 x d_ff x d_model + 2 x d_model) + d_model.
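Plugging in the GPT-2 XL configuration (vocab_size = 50,257, n = 48 layers, d_model = 1,600, d_ff = 6,400) as a check:
vocab_size, n_layers, d_model, d_ff = 50257, 48, 1600, 6400
params = (2 * vocab_size * d_model
          + n_layers * (4 * d_model**2 + 3 * d_ff * d_model + 2 * d_model)
          + d_model)
print(params)              # 2127057600
print(params * 4 / 2**20)  # ~8114.08 MB in FP32 (4 bytes per parameter)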
Total parameters: 2,127,057,600, about 2.13B.
For single-precision floating point (FP32, 4 bytes per parameter), the memory requirement is 8114.08 MB, about 7.92 GB.
(b) FLOPs
The FLOPs count accounts for all matrix multiplications, mainly the n transformer blocks and the final Linear (LM head).
Within a transformer block, the dominant computations are multi-head attention and the SwiGLU feed-forward network; RMSNorm and RoPE are comparatively negligible and are ignored.
- Linear: 2 x ... x d_in x d_out
- scaled_dot_product_attention: 4 x ... x s_q x s_k x d_k
- MultiheadSelfAttention
  - Linear (q, k, v, out): 4 x (2 x batch x seq x d_model x d_model)
    = 8 x batch x seq x d_model x d_model
  - attention dot products: 4 x batch x head x s_q x s_k x d_k, with s_q = s_k = seq and head x d_k = d_model
    = 4 x batch x d_model x s_q x s_q
  - total: 8 x batch x seq x d_model x d_model + 4 x batch x d_model x s_q x s_q
    = batch x seq x d_model x (8 x d_model + 4 x seq)
- SwiGLU
  - 6 x batch x seq x d_model x d_ff
- Transformer block
  - batch x seq x d_model x (8 x d_model + 4 x seq) + 6 x batch x seq x d_model x d_ff
    = batch x seq x d_model x (8 x d_model + 4 x seq + 6 x d_ff)
- Total
  - n x [batch x seq x d_model x (8 x d_model + 4 x seq + 6 x d_ff)] + 2 x batch x seq x d_model x vocab_size
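As a sanity check (a sketch, using the GPT-2 XL sizes from the parameter count above, batch = 1, seq = 1024):
vocab_size, n_layers, d_model, d_ff, seq = 50257, 48, 1600, 6400, 1024
attn_per_tok    = 8 * d_model**2 + 4 * seq * d_model       # ~0.03 GFLOPs per token
ffn_per_tok     = 6 * d_model * d_ff                        # ~0.06 GFLOPs per token
layers_per_tok  = n_layers * (attn_per_tok + ffn_per_tok)   # ~4.25 GFLOPs per token
lm_head_per_tok = 2 * d_model * vocab_size                  # ~0.16 GFLOPs per token
print(seq * (layers_per_tok + lm_head_per_tok) / 1e12)      # ~4.51 TFLOPs for the full sequence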
(c) Based on the FLOPs calculation above, the transformer blocks account for the vast majority of the FLOPs; within a block, the SwiGLU feed-forward network costs roughly twice as much as multi-head attention at this context length (about 0.06 vs 0.03 GFLOPs per token).
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.03 GFLOPs, ffn FLOPs: 0.06 GFLOPs, n_layers FLOPs: 4.25 GFLOPs, final linear FLOPs: 0.16 GFLOPs, Total seq FLOPs: 4.51 TFLOPs
Percentage of n_layers FLOPs: 0.9635, final linear FLOPs: 0.0365
(d)
Small model FLOPs analysis:
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.01 GFLOPs, ffn FLOPs: 0.03 GFLOPs, n_layers FLOPs: 1.79 GFLOPs, final linear FLOPs: 0.08 GFLOPs, Total seq FLOPs: 1.92 TFLOPs
Percentage of n_layers FLOPs: 0.9587, final linear FLOPs: 0.0413
Medium model FLOPs analysis:
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.01 GFLOPs, ffn FLOPs: 0.04 GFLOPs, n_layers FLOPs: 2.49 GFLOPs, final linear FLOPs: 0.10 GFLOPs, Total seq FLOPs: 2.66 TFLOPs
Percentage of n_layers FLOPs: 0.9603, final linear FLOPs: 0.0397
Large model FLOPs analysis:
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.02 GFLOPs, ffn FLOPs: 0.05 GFLOPs, n_layers FLOPs: 3.24 GFLOPs, final linear FLOPs: 0.13 GFLOPs, Total seq FLOPs: 3.45 TFLOPs
Percentage of n_layers FLOPs: 0.9618, final linear FLOPs: 0.0382
XLarge model FLOPs analysis:
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.03 GFLOPs, ffn FLOPs: 0.06 GFLOPs, n_layers FLOPs: 4.25 GFLOPs, final linear FLOPs: 0.16 GFLOPs, Total seq FLOPs: 4.51 TFLOPs
Percentage of n_layers FLOPs: 0.9635, final linear FLOPs: 0.0365
Based on the above, as model size increases the transformer blocks account for a growing share of the total FLOPs (0.9587 -> 0.9603 -> 0.9618 -> 0.9635 from small to XLarge), while the final linear layer's share shrinks correspondingly.
(e)
Context length 16K FLOPs analysis:
For single batch with seq = context_length, for one token, the multiattn FLOPs: 0.13 GFLOPs, ffn FLOPs: 0.06 GFLOPs, n_layers FLOPs: 8.97 GFLOPs, final linear FLOPs: 0.16 GFLOPs, Total seq FLOPs: 149.52 TFLOPs
Percentage of n_layers FLOPs: 0.9824, final linear FLOPs: 0.0176
At a 16K context length, the per-token multi-head attention FLOPs grow markedly (0.03 -> 0.13 GFLOPs) because of the 4 x seq x d_model term, while the per-token FFN FLOPs stay constant, so the total FLOPs increase substantially. The transformer blocks' share of FLOPs rises further, and the final linear layer's share falls.
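A quick check with the per-token formulas above (seq = 16,384, GPT-2 XL sizes):
d_model, d_ff, n_layers, vocab_size, seq = 1600, 6400, 48, 50257, 16384
attn_per_tok   = 8 * d_model**2 + 4 * seq * d_model             # ~0.13 GFLOPs per token
ffn_per_tok    = 6 * d_model * d_ff                              # ~0.06 GFLOPs, unchanged
layers_per_tok = n_layers * (attn_per_tok + ffn_per_tok)         # ~8.97 GFLOPs per token
print(seq * (layers_per_tok + 2 * d_model * vocab_size) / 1e12)  # ~149.5 TFLOPs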