FREE Reverse Engineering Self-Study Course HERE
Part 6 covers full self-attention with Query, Key, and Value projections, the complete attention mechanism with data-dependent weights, scaled attention to prevent peaky softmax, and layer normalization for training stability.
Author: Kevin Thomas
Part 5 HERE
Part 7 HERE
import torch
import torch.nn as nn
from torch.nn import functional as FNow let's read the file and see what we're working with. Understanding your data is crucial before building any model!
with open('input.txt', 'r', encoding='utf-8') as f:
text = f.read()textOutput:
'A dim glow rises behind the glass of a screen and the machine exhales in binary tides. The hum is a language and one who listens leans close to catch the quiet grammar. Patterns fold like small maps and seams hint at how the thing holds itself together. Treat each blinking diode and each idle tick as a sentence in a story that asks to be read.\n\nThere is patience here, not of haste but of careful unthreading. Where others see a sealed box the curious hand traces the join and wonders which thought made it fit. Do not rush to break, coax the meaning out with questions, and watch how the logic replies in traces and errors and in the echoes of forgotten interfaces.\n\nTechnology is artifact and argument at once. It makes a claim about what should be simple, what should be hidden, and what should be trusted. Reverse the gaze and learn its rhetoric, see where it promises ease, where it buries complexity, and where it leaves a backdoor as a sigh between bricks. To read that rhetoric is to be a kind interpreter, not a vandal.\n\nThis work is an apprenticeship in humility. Expect bafflement and expect to be corrected by small things, a timing oddity, a mismatch of expectation, a choice that favors speed over grace. Each misstep teaches a vocabulary of trade offs. Each discovery is a map of decisions and not a verdict on worth.\n\nThere is a moral keeping in the craft. Let curiosity be tempered with regard for consequence. Let repair and understanding lead rather than exploitation. The skill that opens a lock should also know when to hold the key and when to hand it back, mindful of harm and mindful of help.\n\nCelebrate the quiet victories, a stubborn protocol understood, an obscure format rendered speakable, a closed device coaxed into cooperation. These are small reconciliations between human intent and metal will, acts of translation rather than acts of conquest.\n\nAfter decoding a mechanism pause and ask what should change, a bug to be fixed, a user to be warned, a design to be amended. The true maker of machines leaves things better for having looked, not simply for having cracked the shell.'
Now we get to the real thing! In self-attention we have the following.
- Each token produces a Query (Q): "What am I looking for?"
- Each token produces a Key (K): "What do I contain?"
- Each token produces a Value (V): "What information do I provide?"
The attention weights are computed as: wei = Q @ K^T
- High dot product equals tokens are "relevant" to each other.
- This is data-dependent where different inputs give different weights!
Think of it like a search engine.
- Query: Your search terms (what you're looking for)
- Key: The titles/tags of documents (what each document contains)
- Value: The actual content of documents (what you get back)
The attention mechanism works in three steps.
- Compute how well each query matches each key (dot product)
- Normalize these scores to probabilities (softmax)
- Use probabilities to weight the values (weighted sum)
| Step | Tensor | Shape | Meaning |
|---|---|---|---|
| Input | x | (B, T, C) | Batch of sequences with C features per token |
| Keys | k | (B, T, head_size) | Each token's "what I have" vector |
| Queries | q | (B, T, head_size) | Each token's "what I'm looking for" vector |
| Values | v | (B, T, head_size) | Each token's "what I'll give" vector |
| Raw Attention | wei | (B, T, T) | How much each position attends to each other position |
| Output | out | (B, T, head_size) | Weighted sum of values for each position |
torch.manual_seed(42)Output:
<torch._C.Generator at 0x117c8e4f0>
# define batch dimension
B = 4 # batch size: 4 independent sequences
BOutput:
4
# define time dimension
T = 8 # sequence length: 8 tokens/positions in each sequence
TOutput:
8
# define channel dimension
C = 32 # feature size: 32 features per token
COutput:
32
# start with random data
x = torch.randn(B, T, C)
xOutput:
tensor([[[ 1.9269, 1.4873, 0.9007, ..., 0.0418, -0.2516, 0.8599],
[-1.3847, -0.8712, -0.2234, ..., 1.8446, -1.1845, 1.3835],
[ 1.4451, 0.8564, 2.2181, ..., -0.8278, 1.3347, 0.4835],
...,
[-1.9006, 0.2286, 0.0249, ..., -0.5558, 0.7043, 0.7099],
[ 1.7744, -0.9216, 0.9624, ..., -0.5003, 1.0350, 1.6896],
[-0.0045, 1.6668, 0.1539, ..., 0.5655, 0.5058, 0.2225]],
[[-0.6855, 0.5636, -1.5072, ..., 1.1566, 0.2691, -0.0366],
[ 0.9733, -1.0151, -0.5419, ..., -0.0553, 1.2049, -0.9825],
[ 0.4334, -0.7172, 1.0554, ..., -0.6766, -0.5730, -0.3303],
...,
[ 0.6839, -1.3246, -0.5161, ..., 1.1895, 0.7607, -0.7463],
[-1.3839, 0.4869, -1.0020, ..., 1.9535, 2.0487, -1.0880],
[ 1.6217, 0.8513, -0.4005, ..., 0.4232, -0.3389, 0.5180]],
[[-1.3638, 0.1930, -0.6103, ..., 0.6110, 1.2208, -0.6076],
[-1.7376, -0.1254, -1.3658, ..., -0.6035, -0.1743, 0.6092],
[-0.8032, -1.1209, 0.1956, ..., 0.1598, 1.7698, 0.6268],
...,
[ 2.1296, -1.5181, 0.1387, ..., 3.0250, 1.3463, 0.8556],
[ 0.3220, 0.4461, 1.5230, ..., -1.4591, -1.4937, -0.2214],
[ 0.2252, -0.0772, 0.9857, ..., -1.6034, -0.4298, 0.5762]],
[[ 0.3444, -3.1016, -1.4587, ..., 1.4162, 0.6834, -0.1383],
[ 0.9213, 0.5282, -0.0082, ..., 2.1477, -0.6604, 0.1135],
[-0.2206, 0.7118, 0.3416, ..., 1.1383, -0.2505, 1.6705],
...,
[ 0.0518, -0.3285, -2.2472, ..., 1.4557, -0.3461, -0.2634],
[-0.4477, -0.7288, -0.1607, ..., 0.5405, 0.4351, -2.2717],
[-0.1339, -0.0586, 0.1257, ..., 1.1085, 0.5544, 1.5818]]])
# head size: dimension of queries, keys, and values
head_size = 16
head_sizeOutput:
16
# learnable linear transformation, projecting input to key vectors; "What I have."
key = nn.Linear(C, head_size, bias=False)
keyOutput:
Linear(in_features=32, out_features=16, bias=False)
# learnable linear transformation, projecting input to query vectors; "What I'm looking for?"
query = nn.Linear(C, head_size, bias=False)
queryOutput:
Linear(in_features=32, out_features=16, bias=False)
# learnable linear transformation, projecting input to value vectors; "What I'll give if queried."
value = nn.Linear(C, head_size, bias=False)
valueOutput:
Linear(in_features=32, out_features=16, bias=False)
# understand what nn.Linear does
print('understanding the linear projections')
print()
print('nn.Linear(C, head_size, bias=False) creates a matrix of shape (C, head_size)')
print(f' C = {C} (input features)')
print(f' head_size = {head_size} (output features)')
print()
print('when we call key(x), it computes: x @ key.weight.T')
print(f' x shape: {x.shape} = (B={B}, T={T}, C={C})')
print(f' key.weight shape: {key.weight.shape} = (head_size={head_size}, C={C})')
print(f' key.weight.T shape: ({C}, {head_size})')
print()
print('matrix multiplication')
print(f' (B, T, C) @ (C, head_size) = (B, T, head_size)')
print(f' ({B}, {T}, {C}) @ ({C}, {head_size}) = ({B}, {T}, {head_size})')
print()
print('each token\'s C-dimensional vector gets projected to head_size dimensions')
print('these projections are LEARNED during training')Output:
understanding the linear projections
nn.Linear(C, head_size, bias=False) creates a matrix of shape (C, head_size)
C = 32 (input features)
head_size = 16 (output features)
when we call key(x), it computes: x @ key.weight.T
x shape: torch.Size([4, 8, 32]) = (B=4, T=8, C=32)
key.weight shape: torch.Size([16, 32]) = (head_size=16, C=32)
key.weight.T shape: (32, 16)
matrix multiplication
(B, T, C) @ (C, head_size) = (B, T, head_size)
(4, 8, 32) @ (32, 16) = (4, 8, 16)
each token's C-dimensional vector gets projected to head_size dimensions
these projections are LEARNED during training
# compute keys
k = key(x) # (B, T, head_size) = (4, 8, 16)
kOutput:
tensor([[[ 7.0784e-02, -9.4861e-01, -5.9983e-01, -8.8679e-01, 4.7325e-02,
-5.0741e-03, -6.7452e-03, 6.4850e-01, -3.2939e-01, 6.5462e-01,
-3.5305e-01, 3.8077e-01, 3.3350e-01, -1.9763e-01, -1.5752e-01,
-3.8165e-01],
[-1.1677e+00, 2.7538e-01, 1.6652e+00, -2.7140e-01, 1.4043e-01,
2.7449e-03, -1.0794e+00, 2.6188e-01, -1.1814e-01, -5.4476e-01,
7.9574e-02, -3.6371e-02, 8.4531e-01, 8.1885e-01, 2.1071e-01,
-5.8136e-01],
[ 4.3110e-01, -6.2462e-01, -1.8344e-01, -2.9284e-01, 2.5957e-01,
-5.1398e-01, -7.4316e-01, 1.2174e-01, -1.1386e+00, 4.8304e-01,
-1.7443e-02, 6.1590e-01, -5.5573e-02, -1.0868e+00, -1.1277e+00,
8.2533e-02],
[-1.5513e-01, 3.9932e-01, 6.9395e-01, -1.9858e-01, 9.4391e-02,
-1.5222e-01, -5.7460e-01, 5.1392e-01, 8.2661e-01, 1.7277e-02,
3.8279e-01, -9.2914e-01, 3.2609e-01, 1.7348e-01, 1.0189e-01,
-2.2036e-01],
[ 1.8306e-01, 4.2543e-01, 4.3122e-01, -6.0842e-01, 6.0081e-01,
3.6623e-01, -8.6503e-02, -9.6163e-02, -1.4067e-01, -2.3484e-01,
5.9090e-01, 4.2985e-01, -6.2456e-01, -9.1971e-01, 1.2549e-01,
6.8989e-01],
[-3.7444e-01, 1.0035e+00, 9.7462e-01, 3.2149e-01, -6.8795e-01,
-1.1626e+00, -5.8433e-01, 7.6088e-01, 4.3164e-01, 6.5277e-02,
5.4637e-01, -1.6951e-01, 1.5960e-01, 1.8030e-02, 8.5434e-01,
-8.5818e-01],
[-9.3779e-01, 1.2550e+00, 1.2327e+00, 4.0494e-01, -8.2321e-01,
-1.2013e-01, -8.0551e-01, 5.5824e-01, 2.4750e-01, -9.8268e-02,
1.1595e+00, 9.4443e-02, -6.8768e-02, -5.7340e-02, 4.2941e-01,
1.0696e-01],
[ 4.4332e-01, 6.7758e-01, -7.4604e-01, -1.0522e-02, -2.5499e-02,
1.2013e+00, -4.0023e-01, 5.1752e-01, -4.0429e-01, 5.1114e-01,
-4.5502e-01, -6.9189e-01, -2.5769e-01, 7.8132e-01, 2.9977e-01,
-2.0309e-01]],
[[ 2.7949e-01, 3.8377e-01, 5.2063e-01, -5.3440e-01, 1.0310e+00,
7.4677e-01, -5.8959e-01, 1.4523e+00, 4.3332e-01, 4.1400e-02,
8.9054e-01, -2.5275e+00, 9.0357e-01, 6.1530e-01, 3.6645e-01,
-7.7466e-01],
[ 5.5106e-02, 6.2537e-01, -1.5956e-02, 2.7005e-01, -5.0123e-01,
-9.2878e-01, 3.8933e-01, -1.2394e+00, 2.2684e-01, 4.0359e-01,
5.6402e-02, 1.1318e+00, -1.4728e+00, -1.4327e-01, -1.8886e-01,
6.7431e-01],
[-3.9762e-01, -1.6270e-01, -1.8589e-01, 4.7662e-02, -2.0473e-01,
4.1716e-01, -1.8120e-01, 8.6617e-01, -5.7387e-01, -1.3463e-01,
1.1304e-01, 2.0634e-01, 2.0002e-01, -1.4251e-01, 3.2072e-01,
4.8422e-01],
[-7.1437e-01, 2.3439e-01, 5.0224e-01, 3.1803e-01, 3.9729e-01,
8.6735e-02, -1.1195e-01, 4.7673e-01, -1.7878e-01, 2.8000e-01,
-7.3027e-01, 4.8926e-01, -6.1345e-01, -3.5076e-01, -4.3743e-01,
6.1672e-01],
[ 8.8705e-01, -7.7208e-01, -9.7915e-01, 1.5324e-02, 1.0060e-01,
4.7434e-02, -8.7970e-02, -7.6677e-01, -6.2420e-01, 4.1531e-01,
-1.5666e-01, 5.4941e-01, -6.7704e-02, -3.3108e-01, -5.1641e-01,
-8.0151e-02],
[ 1.0009e+00, 1.6927e-01, 4.2891e-02, -3.0627e-01, -3.1331e-01,
-1.1513e-02, 3.7705e-01, -4.0904e-01, 2.8013e-01, 4.4826e-01,
-1.8397e-01, 1.0289e-01, -3.9065e-01, -1.8056e-01, -5.4006e-01,
1.0496e-01],
[-7.9961e-01, 5.5213e-01, 4.5932e-01, 2.5351e-01, -4.8840e-01,
2.8189e-01, -1.8903e-01, 4.2532e-01, 4.9572e-02, -4.2261e-01,
-7.9938e-01, 1.5264e-01, -3.0169e-01, 7.4319e-01, 1.0023e+00,
-6.4789e-01],
[ 7.2109e-02, -2.5520e-01, -1.3666e+00, -8.5867e-01, -5.4527e-01,
-1.5793e-01, 6.0183e-01, 4.0682e-01, -5.7640e-01, 1.6719e-01,
-7.1796e-01, 4.8087e-01, -2.2927e-01, 6.8334e-01, 4.9681e-01,
1.2973e-01]],
[[-3.0747e-01, -8.4684e-01, -2.0854e-01, -7.0981e-01, 2.2592e-01,
8.0769e-01, -8.7980e-01, 5.2975e-01, -5.3864e-01, -6.9849e-01,
3.5470e-01, 3.9033e-01, -8.8904e-02, -1.9965e-01, 5.3527e-01,
-1.2782e-01],
[ 2.4880e-02, -6.1044e-01, 5.6097e-01, 9.3349e-01, 6.7494e-01,
5.3337e-01, 3.8625e-01, 4.1990e-01, -6.3053e-01, -1.6064e-01,
7.5473e-01, -2.0709e-01, -1.2627e-01, -1.0944e-01, 1.7607e-01,
2.9459e-01],
[ 4.0286e-01, 5.0442e-01, 5.5335e-01, 1.3052e-01, -2.4466e-01,
3.9971e-01, -1.5004e-01, 6.9629e-01, -2.6634e-01, -1.9435e-01,
6.9368e-02, -6.8080e-01, -8.7913e-02, 1.4587e+00, -5.7588e-01,
-7.8258e-02],
[-2.1512e-01, 3.2160e-01, 8.3376e-01, 8.4302e-02, -1.5945e-01,
-1.1895e+00, -4.4709e-01, 5.2719e-01, 5.9774e-01, -3.7033e-01,
9.9006e-02, -5.7947e-01, 3.1496e-01, 1.0079e+00, 1.0719e+00,
-5.6954e-01],
[-2.8004e-01, -2.0767e-01, 2.5970e-01, 8.4324e-01, 2.5237e-01,
-8.8464e-01, -1.0020e-01, -4.5292e-01, -8.3785e-01, -6.2902e-03,
-8.9628e-01, 5.6377e-01, 4.7216e-01, -1.6173e-01, -3.5555e-01,
-1.1371e-01],
[ 1.0536e+00, 1.3793e+00, 6.2686e-01, -1.4053e+00, -5.8837e-01,
9.5645e-01, 1.6153e-01, 4.8475e-01, 8.2436e-01, 9.8684e-01,
-5.6398e-01, -7.7153e-01, -6.8535e-01, 8.9388e-01, -6.0170e-01,
6.5138e-01],
[-9.0461e-01, -3.9238e-01, 2.7969e-01, -1.1222e+00, -1.6980e-01,
-1.2657e-01, 2.4860e-01, -3.6713e-02, -6.6537e-02, 4.1127e-01,
1.4832e-01, 8.3464e-02, 3.6814e-01, -1.0768e+00, 3.0961e-02,
3.7589e-01],
[ 7.0817e-01, -1.2292e+00, -1.1683e+00, -4.6781e-01, 4.1275e-01,
1.7956e-01, 3.1878e-01, -6.3456e-01, 2.6772e-01, 8.7278e-01,
2.8422e-01, 1.0724e-01, 4.0734e-01, -1.4641e+00, -7.2267e-01,
6.8318e-01]],
[[ 2.6565e-01, 5.8151e-01, 5.1564e-02, -3.7245e-03, 1.8873e-02,
7.3230e-01, 2.6759e-01, 1.4722e+00, -4.7369e-02, -9.7170e-01,
3.1321e-01, -6.1068e-01, 8.4183e-01, 1.9014e-01, -6.2080e-01,
5.6473e-01],
[-3.7211e-01, -4.3067e-01, 5.9307e-01, -1.0694e+00, 1.3655e-01,
1.0658e-01, -5.9351e-01, 1.8878e-01, 1.6176e-01, -1.4650e-01,
4.3989e-01, 3.7046e-01, 1.3988e-01, -3.9782e-01, -1.2492e-01,
-4.1166e-01],
[-4.6137e-01, 2.0230e-02, 2.1769e-01, -6.4556e-01, -1.9675e-01,
-5.6151e-01, 2.5675e-01, -1.8731e-01, 1.8114e-01, -6.7379e-01,
-1.0939e+00, 6.0911e-01, 3.5566e-01, 5.7098e-01, -7.0692e-01,
1.0351e-01],
[ 6.4570e-02, -9.8439e-01, 2.1197e-02, 1.3539e-01, 6.8128e-01,
-2.7760e-01, 4.6577e-01, 3.8781e-02, 9.2859e-02, -9.4311e-01,
7.1776e-01, 9.6358e-02, 9.0489e-01, -3.7272e-01, 4.5790e-01,
-4.6718e-01],
[ 2.8714e-01, 1.0547e+00, 3.7973e-02, -5.5168e-01, -5.6647e-01,
2.0369e-02, 4.5265e-01, 1.0909e-01, 2.9652e-01, -6.2592e-02,
-4.7420e-02, -6.3049e-01, -2.2637e-01, 1.8829e+00, -3.6382e-01,
1.0184e+00],
[-4.4120e-02, 5.2540e-01, -4.7653e-01, 8.1647e-01, -9.2644e-02,
3.2372e-01, -5.9504e-03, -3.0648e-03, -2.4775e-01, -9.3111e-01,
-3.5523e-01, 2.7222e-01, 1.2949e-01, 2.3470e-01, 3.6988e-01,
1.0492e-01],
[-2.7105e-01, 2.0783e-01, 6.0952e-01, -8.4854e-01, 9.0097e-02,
-7.3273e-01, -3.5330e-01, -7.2682e-01, 6.9166e-02, -4.4932e-01,
-4.6002e-04, 3.1294e-01, -1.4925e-01, 1.1107e+00, 3.0462e-01,
-1.9615e-01],
[-3.0565e-01, 5.9027e-01, -8.9811e-02, -2.6312e-01, -5.1747e-02,
-3.0924e-01, 2.4340e-02, 8.7760e-01, 5.4526e-01, -3.1863e-01,
-1.7977e-01, -9.8312e-01, 8.7145e-01, 5.3252e-01, -1.1281e+00,
8.0423e-02]]], grad_fn=<UnsafeViewBackward0>)
# k shape
k.shapeOutput:
torch.Size([4, 8, 16])
# understand the keys tensor
print('understanding the keys (k)')
print()
print(f'k shape: {k.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each dimension mean?')
print(f' B={B}: we have {B} independent sequences (batch)')
print(f' T={T}: each sequence has {T} tokens/positions')
print(f' head_size={head_size}: each key vector has {head_size} dimensions')
print()
print('for batch 0, position 0')
print(f' k[0, 0] = {k[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is position 0\'s "what I contain" vector')
print()
print('for batch 0, position 1')
print(f' k[0, 1] = {k[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is position 1\'s "what I contain" vector')
print()
print('each position has its own unique key vector')
print('keys tell us what information each token has to offer')Output:
understanding the keys (k)
k shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
what does each dimension mean?
B=4: we have 4 independent sequences (batch)
T=8: each sequence has 8 tokens/positions
head_size=16: each key vector has 16 dimensions
for batch 0, position 0
k[0, 0] = [0.07078436762094498, -0.9486113786697388, -0.5998315811157227, -0.8867908120155334]... (first 4 of 16 values)
this is position 0's "what I contain" vector
for batch 0, position 1
k[0, 1] = [-1.1676945686340332, 0.27537861466407776, 1.6651532649993896, -0.271398663520813]... (first 4 of 16 values)
this is position 1's "what I contain" vector
each position has its own unique key vector
keys tell us what information each token has to offer
# computer queries
q = query(x) # (B, T, head_size) = (4, 8, 16)
qOutput:
tensor([[[-0.4467, 0.6160, -0.5806, -0.4032, -1.2520, -0.3569, 0.3134,
-0.7361, 0.9993, 0.0667, -0.5147, 0.0405, 0.2954, -0.2314,
-0.5547, 0.0771],
[-0.5397, 0.7083, -0.7480, -0.7873, 0.8294, -0.5856, 0.1124,
-0.3189, 0.3401, -0.4997, 0.9458, -0.9430, 0.4883, 0.3419,
-0.2482, -0.1744],
[-0.6314, 0.3053, -0.8942, -0.2582, -1.1889, -0.5720, -0.1693,
-0.6428, 0.3851, -0.3830, -0.7606, 0.9779, -0.4845, -0.7168,
-0.1794, -0.3390],
[-0.4691, -0.1064, 0.0868, -0.1737, 0.3330, 0.1009, 0.0584,
-0.0182, 0.1981, -0.0587, 0.4190, -0.6249, 0.1330, 0.0505,
-0.7110, -0.1839],
[ 0.9760, 0.0929, -0.5380, -0.2953, -0.6723, 0.3561, 0.2937,
-0.0166, -0.4360, -0.4318, -1.0178, -1.0906, -0.0714, 0.4877,
0.9537, 0.1052],
[-0.2460, 0.1963, -0.0330, -0.0296, 0.4262, 0.5681, -0.5607,
0.3858, 0.6359, -0.1462, 0.9695, 0.5166, 0.0834, 0.6468,
-0.6654, 0.6578],
[ 0.1025, -0.6487, -1.0766, 0.5657, -0.5930, 0.7750, -0.3525,
0.5930, -0.7849, -0.0991, 0.1253, 0.2647, 0.0772, 0.1080,
-0.3049, 0.7578],
[ 0.0623, 0.0865, 0.4167, 0.4756, -0.6816, -0.5767, -0.0152,
-0.0497, 0.4178, 0.6291, 0.0292, -0.3211, -0.9214, 0.4931,
0.4157, 0.3642]],
[[ 0.0282, -0.4232, -0.0589, 0.2576, -0.5406, 0.1493, -0.0072,
-0.1857, 0.2729, 0.0872, 0.5680, -0.3410, -0.8786, 0.1489,
-0.3439, -0.5695],
[ 0.3350, -0.0913, 0.3372, 0.5263, 0.2454, 0.7274, -1.1037,
1.2901, -0.2334, 0.2306, -0.2329, 1.0960, -0.5200, -0.2347,
0.5013, 0.1544],
[ 0.3323, -0.3075, -0.5424, -0.2402, -0.6157, -0.2720, 0.2610,
-0.1437, -0.6736, -0.4015, -0.4294, -0.2657, 1.3191, -0.0347,
0.3950, -0.0393],
[ 0.5122, 0.4843, 0.4991, -0.2277, 0.8109, 0.4206, -0.2067,
0.6427, 0.3992, -0.5464, -0.3880, -0.2444, 0.6319, 0.8060,
0.3155, -0.3900],
[-0.6884, 0.1610, -0.9344, -0.0771, -0.7246, -1.3584, 0.2215,
-0.4750, 0.5761, -0.1267, 0.3564, -0.0592, -0.7735, -0.3088,
0.1279, 0.3532],
[-0.6722, 0.3934, -0.1227, -0.3349, 0.1176, 0.2358, 0.4627,
-0.6571, -0.1819, 0.3838, 0.1175, 0.9820, -0.5690, -0.0073,
-0.2621, 0.4604],
[ 0.2283, -0.4608, 0.5600, 0.1657, 1.2347, 0.0294, 0.3925,
-0.0448, 0.1652, -0.7090, 0.3402, 0.5073, 0.4177, 0.8930,
-0.1127, -0.7572],
[-0.3415, 0.4639, -0.0265, -0.0486, -0.1049, -1.0631, -0.2458,
-0.1110, 0.6402, 0.4874, -0.0171, 0.4958, -0.5030, 0.6007,
0.4187, 0.6639]],
[[ 0.9324, -0.3873, 0.7694, 0.2836, 0.0797, -0.5609, -0.0542,
0.0164, -0.3471, -0.6118, -0.6670, 0.0451, 1.4065, 0.4547,
0.7767, -0.5672],
[ 1.4794, 0.4999, 0.5846, -0.0675, 0.1716, -0.3754, 1.4422,
0.3717, -0.0733, -0.2583, 0.1913, -0.0275, 0.3572, 0.7829,
0.0142, -1.2853],
[ 0.6440, -0.0962, 0.4457, 0.8338, -0.3444, -0.3782, -0.0274,
-0.0371, 0.5641, 1.0235, 0.1433, 0.6200, -0.5572, 0.5647,
0.0929, -0.6358],
[-0.5921, 0.2062, -0.0420, 0.2374, 0.7923, 0.6848, -0.2879,
0.6016, 1.3414, -0.4967, 1.6375, -0.5447, -1.2369, 0.8092,
-1.1227, -0.2975],
[-0.8941, -0.2259, -0.0089, -0.2834, 0.5162, -0.1793, -0.7540,
0.0698, 0.4712, -0.6827, -0.3188, -0.1305, 0.2372, -0.4696,
0.1738, -0.6162],
[ 0.2228, 0.2739, -0.3661, -0.3544, -0.1279, 0.1681, 0.3655,
-1.0201, 0.0591, 1.3898, 0.0566, 0.3300, -1.5571, 0.4145,
-0.6352, 0.3029],
[-0.2113, 0.1152, -0.7422, -1.1339, 0.4245, 0.0725, 0.6782,
-0.6515, -0.6226, -0.6749, -0.3655, -0.0432, 0.7225, -0.4106,
-0.5273, -0.3364],
[-0.5688, 0.4958, -0.2908, -0.2989, -0.4141, -2.0439, 0.9589,
-0.4933, -0.5134, -0.6563, -0.4388, -0.0541, 0.3163, -0.6713,
-0.1591, -0.2476]],
[[ 0.0763, -0.7024, 0.1258, -0.1102, 0.2028, -0.0728, 0.9231,
-0.1842, -0.3658, 0.2163, -0.2010, -0.0635, 1.2323, 0.6367,
0.7457, 0.8029],
[-0.1333, -0.1337, -0.3907, 0.3097, -0.4754, 0.6433, -1.5990,
0.2370, 0.7829, -0.2601, -0.9853, 0.1865, -0.0083, 0.3066,
0.1418, 0.3310],
[-0.7890, 0.5769, -0.2716, -0.3727, 0.5939, -0.3055, 0.4481,
-1.1364, 0.7368, 0.1956, 0.3569, 0.0049, -0.1753, 0.1731,
-0.4684, 0.0436],
[ 0.0596, -0.4374, -0.1358, 0.2870, -0.8057, -0.1064, -0.1653,
0.4008, 0.5390, -0.1343, 1.0012, 0.3865, 0.9892, 0.4055,
0.4856, 0.7514],
[ 0.5913, -0.3046, 0.5919, 1.5844, -0.7266, 0.6204, -0.7069,
1.2992, 0.5931, 1.1558, 0.4605, -0.1100, -0.4958, 0.1390,
0.4130, -0.2124],
[-0.3426, -0.5883, 0.2461, 0.4987, 0.2900, 0.2267, 0.2203,
-0.0994, -0.1966, -0.0249, 0.4232, -0.4375, 0.1825, 0.1140,
0.4494, 0.7788],
[ 0.4462, 0.0769, 0.3909, 0.5763, -0.6788, 1.3467, -1.1254,
1.5560, 0.3330, -0.2107, 0.1921, -0.4350, 0.8739, 0.1294,
1.5133, 0.0672],
[-0.6450, -0.1364, 0.1006, 0.1115, 0.1079, 0.7784, -0.0366,
-0.4395, 0.5661, 0.8980, 0.2271, 0.0808, -0.0898, -0.5526,
-0.8941, 0.6091]]], grad_fn=<UnsafeViewBackward0>)
# query shape
q.shapeOutput:
torch.Size([4, 8, 16])
# understand the queries tensor
print('understanding the queries (q)')
print()
print(f'q shape: {q.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each dimension mean?')
print(f' B={B}: we have {B} independent sequences (batch)')
print(f' T={T}: each sequence has {T} tokens/positions')
print(f' head_size={head_size}: each query vector has {head_size} dimensions')
print()
print('for batch 0, position 0')
print(f' q[0, 0] = {q[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is position 0\'s "what I\'m looking for" vector')
print()
print('for batch 0, position 1')
print(f' q[0, 1] = {q[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is position 1\'s "what I\'m looking for" vector')
print()
print('each position has its own unique query vector')
print('queries tell us what information each token is searching for')Output:
understanding the queries (q)
q shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
what does each dimension mean?
B=4: we have 4 independent sequences (batch)
T=8: each sequence has 8 tokens/positions
head_size=16: each query vector has 16 dimensions
for batch 0, position 0
q[0, 0] = [-0.44672077894210815, 0.6160160899162292, -0.580593466758728, -0.4032447636127472]... (first 4 of 16 values)
this is position 0's "what I'm looking for" vector
for batch 0, position 1
q[0, 1] = [-0.5397410988807678, 0.7082756161689758, -0.7479986548423767, -0.7873462438583374]... (first 4 of 16 values)
this is position 1's "what I'm looking for" vector
each position has its own unique query vector
queries tell us what information each token is searching for
# compute raw attention scores (affinities)
# wei[b, i, j] = how much position i attends to position j
# computed as dot product of query[i] and key[j]
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) → (B, T, T)
weiOutput:
tensor([[[-0.3332, -1.1723, -1.0216, -0.0545, -1.0950, 0.2735, 0.1340,
-0.8490],
[-0.6597, 0.7869, -1.2725, 1.6851, 0.1159, 0.5450, 0.2356,
-0.1962],
[ 0.3630, -1.5219, 0.7821, -1.7215, -0.3494, 0.2884, -0.1021,
-1.4271],
[-0.1001, 0.8649, -0.0335, 1.0221, -0.1350, -0.3078, 0.1440,
-0.3019],
[ 0.0136, -1.6202, -1.9888, -0.3327, -1.2506, -0.8928, -2.2674,
3.0561],
[-0.5833, 1.2025, -0.3281, 0.9147, 0.9809, -0.4859, 1.7589,
0.1650],
[ 1.1351, -1.9940, 1.5545, -1.8037, -0.5062, -2.6109, -1.0739,
1.6430],
[-1.2784, -0.4554, -1.4118, 0.6392, -0.5780, 1.9291, 1.6689,
0.1103]],
[[ 0.0685, 0.8637, -0.6632, -0.6766, 0.3320, 0.5343, -0.1189,
-0.4954],
[-0.3457, -0.6529, 1.8824, 2.0650, -0.0657, -0.6715, 1.4419,
-0.1583],
[-0.0532, -2.0680, 0.5618, -1.8266, 0.9683, -0.3595, -0.1953,
1.8939],
[ 4.8230, -3.2096, -0.1085, -0.4954, -1.7757, -0.6045, 1.5651,
-0.5931],
[-3.5622, 3.7869, -0.6365, -0.3043, -0.0541, 0.2013, -0.3909,
1.2404],
[-4.2631, 3.2677, -0.0738, 1.3912, 0.6345, 0.5088, -0.2544,
0.8281],
[ 1.6236, -1.5654, -0.7652, -0.7248, -0.2393, -0.7221, 0.2028,
-1.0873],
[-2.2400, 3.2660, -0.4738, 0.7659, -0.9465, 0.1319, 0.7583,
0.7500]],
[[-0.0314, 0.2432, 0.6900, 3.0115, 2.0694, -1.9951, -1.2359,
-1.8439],
[-2.0277, 0.3562, 2.3180, 2.0301, -0.3015, 1.8700, -2.2014,
-2.0738],
[-1.8383, -0.0600, 0.5855, 1.2383, 0.3526, 0.8046, -1.8479,
-0.9355],
[ 0.5299, 1.4638, 2.5979, 0.5391, -3.5146, 1.8997, -1.7077,
-1.5062],
[ 1.6449, -0.6936, -1.2360, 1.3106, 0.4067, -2.5170, 0.9544,
-0.6532],
[-1.9045, -1.1392, -0.1105, -2.7667, -0.8025, 3.8325, -0.2067,
1.8609],
[ 0.5785, -1.2410, -1.2474, -2.1258, 0.6188, -1.0260, 1.5934,
2.1066],
[-2.2108, -1.9236, -2.0073, 1.1402, 2.6994, -3.1374, 1.5955,
-0.2352]],
[[ 0.4770, -0.8036, 0.6485, 1.7210, 1.0219, 0.0428, 0.3335,
-0.0809],
[ 0.1920, 0.0650, 0.6659, -1.6642, 0.2617, 1.2427, -0.0642,
0.4364],
[-1.5678, 0.0473, 0.9261, -0.3383, 1.0800, -0.6310, 1.4611,
0.5233],
[ 1.4212, 0.1287, -0.5043, 1.4374, 0.8206, 0.3293, -0.1394,
0.4498],
[ 0.4617, -0.8008, -3.4562, -1.3081, -0.5987, -0.2240, -2.5389,
-0.7223],
[ 0.4312, -0.4930, -1.0866, 1.0946, -0.0532, 0.1422, -0.5041,
-0.5325],
[ 3.5177, 0.2910, -2.6452, 0.4690, -0.0807, 1.4431, -1.6943,
0.4496],
[-0.4898, 0.4337, -0.4948, -1.2595, -0.5212, -1.1997, -1.4466,
0.0328]]], grad_fn=<UnsafeViewBackward0>)
# understand the attention score computation
print('understanding q @ k.transpose(-2, -1)')
print()
print(f'q shape: {q.shape} = (B, T, head_size) = ({B}, {T}, {head_size})')
print(f'k shape: {k.shape} = (B, T, head_size) = ({B}, {T}, {head_size})')
print()
print('k.transpose(-2, -1) swaps the last two dimensions')
print(f' before: k shape = ({B}, {T}, {head_size})')
print(f' after: k.T shape = ({B}, {head_size}, {T})')
print()
print('the matrix multiplication')
print(f' (B, T, head_size) @ (B, head_size, T) = (B, T, T)')
print(f' ({B}, {T}, {head_size}) @ ({B}, {head_size}, {T}) = ({B}, {T}, {T})')
print()
print('what does wei[b, i, j] mean?')
print(' wei[b, i, j] = dot product of query[i] and key[j] in batch b')
print(' higher value = position i is more interested in position j')
print(' lower value = position i is less interested in position j')
print()
print(f'wei shape: {wei.shape}')Output:
understanding q @ k.transpose(-2, -1)
q shape: torch.Size([4, 8, 16]) = (B, T, head_size) = (4, 8, 16)
k shape: torch.Size([4, 8, 16]) = (B, T, head_size) = (4, 8, 16)
k.transpose(-2, -1) swaps the last two dimensions
before: k shape = (4, 8, 16)
after: k.T shape = (4, 16, 8)
the matrix multiplication
(B, T, head_size) @ (B, head_size, T) = (B, T, T)
(4, 8, 16) @ (4, 16, 8) = (4, 8, 8)
what does wei[b, i, j] mean?
wei[b, i, j] = dot product of query[i] and key[j] in batch b
higher value = position i is more interested in position j
lower value = position i is less interested in position j
wei shape: torch.Size([4, 8, 8])
# trace through one attention score calculation manually
print('tracing through one attention score calculation')
print()
print('let\'s compute wei[0, 0, 1] manually')
print('this is: how much does position 0 attend to position 1 (in batch 0)?')
print()
print(f'q[0, 0] (query at position 0):')
print(f' {q[0, 0].tolist()[:4]}... (first 4 values)')
print()
print(f'k[0, 1] (key at position 1):')
print(f' {k[0, 1].tolist()[:4]}... (first 4 values)')
print()
print('dot product = sum of element-wise products')
manual_dot = (q[0, 0] * k[0, 1]).sum().item()
print(f' q[0, 0] · k[0, 1] = {manual_dot:.4f}')
print()
print(f'actual wei[0, 0, 1] = {wei[0, 0, 1].item():.4f}')
print(f'match: {abs(manual_dot - wei[0, 0, 1].item()) < 1e-5}')Output:
tracing through one attention score calculation
let's compute wei[0, 0, 1] manually
this is: how much does position 0 attend to position 1 (in batch 0)?
q[0, 0] (query at position 0):
[-0.44672077894210815, 0.6160160899162292, -0.580593466758728, -0.4032447636127472]... (first 4 values)
k[0, 1] (key at position 1):
[-1.1676945686340332, 0.27537861466407776, 1.6651532649993896, -0.271398663520813]... (first 4 values)
dot product = sum of element-wise products
q[0, 0] · k[0, 1] = -1.1723
actual wei[0, 0, 1] = -1.1723
match: True
# examine one row of attention scores (before masking)
print('examining raw attention scores for position 0 (batch 0)')
print()
print(f'wei[0, 0] = attention scores from position 0 to all positions')
print(f' {wei[0, 0].tolist()}')
print()
print('interpreting these scores')
for j in range(T):
score = wei[0, 0, j].item()
if score > 0:
print(f' position 0 → position {j}: {score:.4f} (positive = interested)')
else:
print(f' position 0 → position {j}: {score:.4f} (negative = not interested)')
print()
print('note: these are RAW scores, not probabilities yet')
print('we still need to apply masking and softmax')Output:
examining raw attention scores for position 0 (batch 0)
wei[0, 0] = attention scores from position 0 to all positions
[-0.33323800563812256, -1.1722666025161743, -1.0215535163879395, -0.05452167987823486, -1.0950181484222412, 0.27353763580322266, 0.13403844833374023, -0.8489717841148376]
interpreting these scores
position 0 → position 0: -0.3332 (negative = not interested)
position 0 → position 1: -1.1723 (negative = not interested)
position 0 → position 2: -1.0216 (negative = not interested)
position 0 → position 3: -0.0545 (negative = not interested)
position 0 → position 4: -1.0950 (negative = not interested)
position 0 → position 5: 0.2735 (positive = interested)
position 0 → position 6: 0.1340 (positive = interested)
position 0 → position 7: -0.8490 (negative = not interested)
note: these are RAW scores, not probabilities yet
we still need to apply masking and softmax
# create lower-triangular mask
tril = torch.tril(torch.ones(T, T))
trilOutput:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]])
# understand the lower triangular mask
print('understanding the lower triangular mask')
print()
print('tril = torch.tril(torch.ones(T, T))')
print(f'T = {T}, so we create an {T}x{T} matrix')
print()
print('examining each row of tril')
for i in range(T):
row = tril[i].tolist()
visible = [j for j, v in enumerate(row) if v == 1.0]
print(f'row {i}: {[int(v) for v in row]}')
print(f' position {i} can see positions {visible}')
print()Output:
understanding the lower triangular mask
tril = torch.tril(torch.ones(T, T))
T = 8, so we create an 8x8 matrix
examining each row of tril
row 0: [1, 0, 0, 0, 0, 0, 0, 0]
position 0 can see positions [0]
row 1: [1, 1, 0, 0, 0, 0, 0, 0]
position 1 can see positions [0, 1]
row 2: [1, 1, 1, 0, 0, 0, 0, 0]
position 2 can see positions [0, 1, 2]
row 3: [1, 1, 1, 1, 0, 0, 0, 0]
position 3 can see positions [0, 1, 2, 3]
row 4: [1, 1, 1, 1, 1, 0, 0, 0]
position 4 can see positions [0, 1, 2, 3, 4]
row 5: [1, 1, 1, 1, 1, 1, 0, 0]
position 5 can see positions [0, 1, 2, 3, 4, 5]
row 6: [1, 1, 1, 1, 1, 1, 1, 0]
position 6 can see positions [0, 1, 2, 3, 4, 5, 6]
row 7: [1, 1, 1, 1, 1, 1, 1, 1]
position 7 can see positions [0, 1, 2, 3, 4, 5, 6, 7]
# apply the mask to the attention scores
wei = wei.masked_fill(tril == 0, float('-inf'))
weiOutput:
tensor([[[-0.3332, -inf, -inf, -inf, -inf, -inf, -inf,
-inf],
[-0.6597, 0.7869, -inf, -inf, -inf, -inf, -inf,
-inf],
[ 0.3630, -1.5219, 0.7821, -inf, -inf, -inf, -inf,
-inf],
[-0.1001, 0.8649, -0.0335, 1.0221, -inf, -inf, -inf,
-inf],
[ 0.0136, -1.6202, -1.9888, -0.3327, -1.2506, -inf, -inf,
-inf],
[-0.5833, 1.2025, -0.3281, 0.9147, 0.9809, -0.4859, -inf,
-inf],
[ 1.1351, -1.9940, 1.5545, -1.8037, -0.5062, -2.6109, -1.0739,
-inf],
[-1.2784, -0.4554, -1.4118, 0.6392, -0.5780, 1.9291, 1.6689,
0.1103]],
[[ 0.0685, -inf, -inf, -inf, -inf, -inf, -inf,
-inf],
[-0.3457, -0.6529, -inf, -inf, -inf, -inf, -inf,
-inf],
[-0.0532, -2.0680, 0.5618, -inf, -inf, -inf, -inf,
-inf],
[ 4.8230, -3.2096, -0.1085, -0.4954, -inf, -inf, -inf,
-inf],
[-3.5622, 3.7869, -0.6365, -0.3043, -0.0541, -inf, -inf,
-inf],
[-4.2631, 3.2677, -0.0738, 1.3912, 0.6345, 0.5088, -inf,
-inf],
[ 1.6236, -1.5654, -0.7652, -0.7248, -0.2393, -0.7221, 0.2028,
-inf],
[-2.2400, 3.2660, -0.4738, 0.7659, -0.9465, 0.1319, 0.7583,
0.7500]],
[[-0.0314, -inf, -inf, -inf, -inf, -inf, -inf,
-inf],
[-2.0277, 0.3562, -inf, -inf, -inf, -inf, -inf,
-inf],
[-1.8383, -0.0600, 0.5855, -inf, -inf, -inf, -inf,
-inf],
[ 0.5299, 1.4638, 2.5979, 0.5391, -inf, -inf, -inf,
-inf],
[ 1.6449, -0.6936, -1.2360, 1.3106, 0.4067, -inf, -inf,
-inf],
[-1.9045, -1.1392, -0.1105, -2.7667, -0.8025, 3.8325, -inf,
-inf],
[ 0.5785, -1.2410, -1.2474, -2.1258, 0.6188, -1.0260, 1.5934,
-inf],
[-2.2108, -1.9236, -2.0073, 1.1402, 2.6994, -3.1374, 1.5955,
-0.2352]],
[[ 0.4770, -inf, -inf, -inf, -inf, -inf, -inf,
-inf],
[ 0.1920, 0.0650, -inf, -inf, -inf, -inf, -inf,
-inf],
[-1.5678, 0.0473, 0.9261, -inf, -inf, -inf, -inf,
-inf],
[ 1.4212, 0.1287, -0.5043, 1.4374, -inf, -inf, -inf,
-inf],
[ 0.4617, -0.8008, -3.4562, -1.3081, -0.5987, -inf, -inf,
-inf],
[ 0.4312, -0.4930, -1.0866, 1.0946, -0.0532, 0.1422, -inf,
-inf],
[ 3.5177, 0.2910, -2.6452, 0.4690, -0.0807, 1.4431, -1.6943,
-inf],
[-0.4898, 0.4337, -0.4948, -1.2595, -0.5212, -1.1997, -1.4466,
0.0328]]], grad_fn=<MaskedFillBackward0>)
# understand the masked_fill operation
print('understanding masked_fill with -inf')
print()
print('wei = wei.masked_fill(tril == 0, float(\'-inf\'))')
print()
print('step 1: tril == 0 creates a boolean mask')
print(' True where we CANNOT look (future positions)')
print(' False where we CAN look (current and past)')
print()
print('step 2: masked_fill replaces True positions with -inf')
print()
print('examining batch 0, position 0')
print(f' before masking: wei[0, 0] could see all positions')
print(f' after masking: wei[0, 0] = {wei[0, 0].tolist()}')
print(' positions 1-7 are now -inf (cannot look at future)')
print()
print('examining batch 0, position 3')
print(f' wei[0, 3] = {wei[0, 3].tolist()}')
print(' positions 0-3 have real values, positions 4-7 are -inf')
print()
print('why -inf?')
print(' because e^(-inf) = 0')
print(' when we apply softmax, -inf values become 0 probability')
print(' this completely blocks information from future positions')Output:
understanding masked_fill with -inf
wei = wei.masked_fill(tril == 0, float('-inf'))
step 1: tril == 0 creates a boolean mask
True where we CANNOT look (future positions)
False where we CAN look (current and past)
step 2: masked_fill replaces True positions with -inf
examining batch 0, position 0
before masking: wei[0, 0] could see all positions
after masking: wei[0, 0] = [-0.33323800563812256, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
positions 1-7 are now -inf (cannot look at future)
examining batch 0, position 3
wei[0, 3] = [-0.10006709396839142, 0.8648553490638733, -0.033506155014038086, 1.0221478939056396, -inf, -inf, -inf, -inf]
positions 0-3 have real values, positions 4-7 are -inf
why -inf?
because e^(-inf) = 0
when we apply softmax, -inf values become 0 probability
this completely blocks information from future positions
# normalize the attention scores to probabilities
wei = F.softmax(wei, dim=-1)
weiOutput:
tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.9052e-01, 8.0948e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[3.7418e-01, 5.6820e-02, 5.6900e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.2878e-01, 3.3800e-01, 1.3765e-01, 3.9557e-01, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[4.3106e-01, 8.4134e-02, 5.8193e-02, 3.0487e-01, 1.2175e-01,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.3737e-02, 3.2051e-01, 6.9361e-02, 2.4035e-01, 2.5680e-01,
5.9238e-02, 0.0000e+00, 0.0000e+00],
[3.3957e-01, 1.4859e-02, 5.1650e-01, 1.7974e-02, 6.5784e-02,
8.0182e-03, 3.7289e-02, 0.0000e+00],
[1.6459e-02, 3.7486e-02, 1.4404e-02, 1.1200e-01, 3.3159e-02,
4.0686e-01, 3.1364e-01, 6.5997e-02]],
[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.7619e-01, 4.2381e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[3.3523e-01, 4.4705e-02, 6.2006e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[9.8771e-01, 3.2071e-04, 7.1268e-03, 4.8401e-03, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[6.1209e-04, 9.5163e-01, 1.1413e-02, 1.5912e-02, 2.0434e-02,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[4.0501e-04, 7.5514e-01, 2.6720e-02, 1.1563e-01, 5.4259e-02,
4.7846e-02, 0.0000e+00, 0.0000e+00],
[5.8107e-01, 2.3946e-02, 5.3307e-02, 5.5501e-02, 9.0190e-02,
5.5652e-02, 1.4033e-01, 0.0000e+00],
[3.0533e-03, 7.5160e-01, 1.7858e-02, 6.1689e-02, 1.1131e-02,
3.2726e-02, 6.1221e-02, 6.0719e-02]],
[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[8.4403e-02, 9.1560e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.4919e-02, 3.2512e-01, 6.1996e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[8.0239e-02, 2.0417e-01, 6.3460e-01, 8.0984e-02, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[4.6332e-01, 4.4701e-02, 2.5986e-02, 3.3167e-01, 1.3432e-01,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[3.0988e-03, 6.6612e-03, 1.8633e-02, 1.3084e-03, 9.3273e-03,
9.6097e-01, 0.0000e+00, 0.0000e+00],
[1.8549e-01, 3.0068e-02, 2.9878e-02, 1.2412e-02, 1.9312e-01,
3.7283e-02, 5.1176e-01, 0.0000e+00],
[4.5383e-03, 6.0481e-03, 5.5622e-03, 1.2949e-01, 6.1569e-01,
1.7967e-03, 2.0416e-01, 3.2725e-02]],
[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.3170e-01, 4.6830e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.5138e-02, 2.7724e-01, 6.6763e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[4.1039e-01, 1.1268e-01, 5.9834e-02, 4.1710e-01, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[5.4961e-01, 1.5550e-01, 1.0927e-02, 9.3633e-02, 1.9033e-01,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[2.0316e-01, 8.0623e-02, 4.4527e-02, 3.9438e-01, 1.2515e-01,
1.5217e-01, 0.0000e+00, 0.0000e+00],
[8.0151e-01, 3.1811e-02, 1.6882e-03, 3.8008e-02, 2.1936e-02,
1.0068e-01, 4.3691e-03, 0.0000e+00],
[1.1754e-01, 2.9599e-01, 1.1695e-01, 5.4438e-02, 1.1391e-01,
5.7793e-02, 4.5152e-02, 1.9822e-01]]], grad_fn=<SoftmaxBackward0>)
# understand softmax normalization
print('understanding softmax normalization')
print()
print('wei = F.softmax(wei, dim=-1)')
print()
print('softmax converts raw scores to probabilities')
print(' - all values become positive (e^x > 0)')
print(' - all values in a row sum to 1')
print(' - e^(-inf) = 0 (masked positions become 0)')
print()
print('examining batch 0, position 0')
print(f' wei[0, 0] = {[round(v, 4) for v in wei[0, 0].tolist()]}')
print(f' sum = {wei[0, 0].sum().item():.4f}')
print(' only position 0 is visible, so it gets weight 1.0')
print()
print('examining batch 0, position 3')
print(f' wei[0, 3] = {[round(v, 4) for v in wei[0, 3].tolist()]}')
print(f' sum = {wei[0, 3].sum().item():.4f}')
print(' positions 0-3 have non-zero weights that sum to 1')
print()
print('this is DATA-DEPENDENT attention!')
print('the weights are NOT uniform (like 0.25, 0.25, 0.25, 0.25)')
print('instead, they depend on the query-key dot products')Output:
understanding softmax normalization
wei = F.softmax(wei, dim=-1)
softmax converts raw scores to probabilities
- all values become positive (e^x > 0)
- all values in a row sum to 1
- e^(-inf) = 0 (masked positions become 0)
examining batch 0, position 0
wei[0, 0] = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
sum = 1.0000
only position 0 is visible, so it gets weight 1.0
examining batch 0, position 3
wei[0, 3] = [0.1288, 0.338, 0.1376, 0.3956, 0.0, 0.0, 0.0, 0.0]
sum = 1.0000
positions 0-3 have non-zero weights that sum to 1
this is DATA-DEPENDENT attention!
the weights are NOT uniform (like 0.25, 0.25, 0.25, 0.25)
instead, they depend on the query-key dot products
# compute values
v = value(x) # (B, T, head_size)
vOutput:
tensor([[[ 0.7630, -0.2412, -0.4150, 0.3833, 0.5740, -1.6738, 0.7954,
0.6872, -0.3848, 0.5073, -0.5312, -0.1221, 0.0445, 1.2169,
0.9940, 1.5281],
[ 0.3218, -0.0569, -0.8477, -0.7261, 0.0893, -0.1100, -0.0939,
-1.0305, 0.0200, 0.2691, 0.5359, 0.1426, 0.1681, 0.3577,
0.2332, -0.5051],
[-0.1803, 0.2362, 0.1637, 0.5017, 0.7742, -0.4373, -0.1290,
-0.2560, -0.4637, 0.2134, 0.0475, -0.3715, -0.8919, 0.1818,
1.0010, 0.8075],
[-0.3670, 0.6613, 0.3709, 0.2785, -0.5886, -0.1584, -0.5308,
-0.9533, -0.2020, 0.1104, 0.0643, 0.5399, 0.8794, 0.3147,
-0.3706, -0.2829],
[-0.2686, 0.1055, -0.7112, 1.0555, 0.3917, -0.2944, -1.1679,
0.3537, 0.7702, -0.4535, -0.7852, -0.2078, -0.5898, -0.0895,
0.1527, -0.2386],
[-0.3207, -0.1223, 0.2610, -0.5424, -0.0771, -0.4658, 0.2526,
-0.0036, -0.5229, 0.1433, 0.9100, 0.2990, 0.4837, 0.2490,
-0.0583, -0.0475],
[ 0.5571, -0.0288, -0.2621, -0.6581, -0.5829, -0.7672, 0.0551,
-0.7210, -0.0195, -0.4331, -0.6869, -1.0376, -0.0704, 0.3517,
-0.4246, 0.4856],
[-0.1362, -0.4772, -0.0267, -0.5559, -0.2265, 0.1733, 0.3853,
0.6409, -0.2373, -0.0300, 0.1528, 0.1062, 0.1487, 0.1704,
0.5757, 0.6982]],
[[-1.0970, 0.4693, -1.4577, 0.1591, -0.5800, 0.4218, -0.2913,
-0.4384, -0.6460, -0.3378, -0.5066, 1.4041, 0.3326, -0.3679,
0.0863, 0.9563],
[ 0.0487, 0.0677, 0.8809, 0.6150, 0.0918, 0.0283, 0.0851,
-0.1105, 0.8593, 0.1873, -0.1565, -1.3734, -0.3197, -0.5941,
-0.4458, 0.1768],
[ 0.6274, 0.3655, -0.3945, 0.1603, -0.4681, -0.1591, 0.1339,
0.1754, -0.0278, -0.0900, 0.1749, 0.2202, -0.1974, 1.3567,
0.1852, -0.1852],
[ 0.1319, -0.7497, -0.1804, -0.6604, 0.3357, -0.4459, 0.3611,
0.0562, 1.0438, 0.5156, 0.6546, -0.8129, -0.4794, -0.4133,
0.4109, -0.0929],
[-0.4059, 0.1516, -0.1648, 0.6007, 0.2016, 0.6715, 0.2267,
-0.3642, -0.9720, 0.3649, -0.5970, -0.0378, -0.2092, 0.2979,
1.2460, 0.0566],
[-0.6157, -0.6624, -0.3937, 0.1310, 0.4711, 0.8800, 0.2762,
0.7958, 0.4129, 0.7078, -0.0498, 1.0576, 0.3581, -0.2541,
0.6451, 0.3365],
[ 0.8644, 0.4706, -0.1293, -0.6170, -0.5255, 0.4717, 0.8747,
-1.1380, 0.0565, 0.7913, 0.8005, -0.2364, 1.0347, -1.0927,
-0.1297, 0.2886],
[ 0.9694, -0.2682, 0.1122, -0.4056, 0.7449, 0.2006, 0.4392,
0.9176, 0.3531, 0.7124, 0.8635, -1.1062, -0.7719, 0.8359,
0.3966, -0.0352]],
[[ 0.3106, 0.9371, -0.2092, 0.7705, 0.0193, 0.5726, 0.2484,
-0.9781, 0.6597, 0.5011, 0.7583, -0.0264, 0.7907, 0.1997,
-0.4586, -0.6579],
[ 0.0273, 0.3817, 0.6696, -0.6766, 0.0330, -0.3337, -0.5152,
-1.1344, -0.6060, 0.4816, -0.1285, -0.0487, 0.3002, -1.7130,
-0.8980, -0.8737],
[-0.5203, -0.0333, 0.4968, -0.5467, -0.2312, 0.2681, 0.2471,
-0.1932, 0.3291, -0.0638, 0.4587, -0.4808, -0.4177, -0.5063,
0.2593, 0.0968],
[ 0.2236, 0.1102, -0.3086, -0.4969, -0.5499, -0.7894, -0.3272,
-0.7428, -0.8734, -0.0728, 0.1639, 0.1849, 0.8782, -0.4819,
0.0294, 0.7049],
[ 0.0931, 0.3978, -0.1147, -0.2994, 0.4704, 0.0658, 0.4794,
0.0721, -0.4440, -0.1687, 1.2164, -0.1396, -1.5568, 0.2090,
0.9272, -0.4725],
[-0.2449, -1.0554, -0.6655, 0.4118, -0.3106, 0.9528, -0.1242,
0.7541, 1.5634, 0.6041, -0.4255, 0.1662, 0.5790, -0.1300,
0.5525, 0.5753],
[ 0.1237, -0.1306, -0.1187, 0.1009, -0.3313, 0.2688, -0.8307,
0.3135, -0.3795, -0.3788, 0.0449, 0.1885, 0.5644, 0.4173,
0.1475, -0.7490],
[ 0.1074, 0.4722, 0.4970, 0.7408, 0.9470, -0.0367, -0.7831,
-0.5337, -0.5380, 0.2420, -0.7910, 0.7542, -0.0514, 0.7723,
-0.7424, -0.7498]],
[[ 0.0771, -0.1557, -0.9221, -0.3849, 0.4312, 0.7572, 0.2792,
0.5681, 1.3448, 0.3264, 0.1471, 0.4326, -1.2313, 0.6333,
-0.2560, -1.0371],
[-0.6234, -0.1336, -1.6580, 0.5568, 0.8131, -0.3600, 0.6075,
0.4096, 0.7376, -0.1125, 0.0733, -0.1014, -0.2114, 0.8178,
1.1383, 0.8535],
[ 0.2804, -0.0703, 0.5089, -0.2839, 0.6286, 0.4214, 0.1597,
-0.3707, -0.1932, 0.7346, 0.5754, -1.1367, -0.2369, -0.3359,
0.6683, 0.1314],
[-1.1991, 0.0615, -0.8441, -0.5543, 0.3382, 0.0924, 0.6773,
-0.1560, -0.7958, -0.1810, -0.5697, 0.9682, -0.7534, 0.3174,
0.2770, 0.0683],
[-0.3574, 0.1819, 0.5249, 0.2009, -0.9431, 0.5530, -0.6007,
0.2637, 0.4426, -0.8211, 0.3333, -0.9226, 0.0535, -0.0387,
-0.6463, 0.3352],
[-0.1651, -0.0289, -0.4991, -0.4394, 0.0470, 0.8903, 0.6546,
0.2552, 0.1525, 0.4128, -0.2636, -0.0825, -0.3186, -0.3558,
0.4890, -0.0732],
[ 0.4685, 0.3434, -0.8854, 0.8242, -0.2039, -0.5476, 0.0932,
0.3258, 0.5255, -1.1008, 0.7351, 0.0765, -0.1057, 0.6164,
-0.4499, 0.7024],
[-0.0696, 0.0759, 0.5105, 0.2474, 0.0713, 0.0825, 0.2319,
-0.0218, -0.1033, 0.4329, -0.0841, -0.6620, 0.1741, 0.2497,
0.1832, 0.6219]]], grad_fn=<UnsafeViewBackward0>)
# understand values
print('understanding values (v)')
print()
print('v = value(x) projects each token to a "value" representation')
print(f'v shape: {v.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each value vector represent?')
print(' it\'s the "content" that each token will contribute')
print(' if position j gets high attention weight, its value contributes more')
print()
print('for batch 0, position 0')
print(f' v[0, 0] = {v[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is what position 0 will "give" when attended to')
print()
print('for batch 0, position 1')
print(f' v[0, 1] = {v[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print(' this is what position 1 will "give" when attended to')Output:
understanding values (v)
v = value(x) projects each token to a "value" representation
v shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
what does each value vector represent?
it's the "content" that each token will contribute
if position j gets high attention weight, its value contributes more
for batch 0, position 0
v[0, 0] = [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]... (first 4 of 16 values)
this is what position 0 will "give" when attended to
for batch 0, position 1
v[0, 1] = [0.3217601180076599, -0.0569150447845459, -0.8477029800415039, -0.7260561585426331]... (first 4 of 16 values)
this is what position 1 will "give" when attended to
# compute the weighted sum of the values
out = wei @ v # (B, T, T) @ (B, T, 16) → (B, T, 16)
outOutput:
tensor([[[ 7.6297e-01, -2.4118e-01, -4.1502e-01, 3.8330e-01, 5.7404e-01,
-1.6738e+00, 7.9543e-01, 6.8724e-01, -3.8477e-01, 5.0733e-01,
-5.3124e-01, -1.2214e-01, 4.4479e-02, 1.2169e+00, 9.9396e-01,
1.5281e+00],
[ 4.0582e-01, -9.2022e-02, -7.6527e-01, -5.1470e-01, 1.8168e-01,
-4.0795e-01, 7.5564e-02, -7.0327e-01, -5.7126e-02, 3.1450e-01,
3.3258e-01, 9.2198e-02, 1.4456e-01, 5.2138e-01, 3.7812e-01,
-1.1776e-01],
[ 2.0116e-01, 4.0931e-02, -1.1034e-01, 3.8763e-01, 6.6036e-01,
-8.8138e-01, 2.1889e-01, 5.2931e-02, -4.0671e-01, 3.2654e-01,
-1.4131e-01, -2.4897e-01, -4.8128e-01, 5.7909e-01, 9.5475e-01,
1.0026e+00],
[ 3.7002e-02, 2.4381e-01, -1.7073e-01, -1.6806e-02, -2.2144e-02,
-3.7560e-01, -1.5701e-01, -6.7214e-01, -1.8652e-01, 2.2933e-01,
1.4469e-01, 1.9491e-01, 2.8767e-01, 4.2711e-01, 1.9802e-01,
2.5314e-02],
[ 2.0086e-01, 1.1945e-01, -2.1421e-01, 3.4675e-01, 1.6826e-01,
-8.4036e-01, 2.3462e-02, -5.2923e-02, -1.5896e-01, 2.3218e-01,
-2.5714e-01, 7.7031e-02, 1.7773e-01, 6.5026e-01, 4.1195e-01,
5.4791e-01],
[-4.4578e-02, 1.6398e-01, -3.6067e-01, 1.2856e-01, 6.7733e-02,
-2.9680e-01, -4.0882e-01, -4.4963e-01, 7.1842e-02, 4.6864e-02,
1.4222e-02, 1.0751e-01, 8.2987e-02, 2.6006e-01, 1.4429e-01,
-1.5585e-01],
[ 1.6465e-01, 5.6039e-02, -1.1679e-01, 4.2405e-01, 5.8894e-01,
-8.5044e-01, 1.1978e-01, 6.5047e-02, -3.2777e-01, 2.4363e-01,
-2.1672e-01, -2.7148e-01, -4.6480e-01, 5.2730e-01, 8.4511e-01,
9.2544e-01],
[ 7.2800e-03, -1.5442e-02, 3.9108e-03, -4.1127e-01, -2.5813e-01,
-4.8418e-01, 5.5030e-02, -3.1134e-01, -2.4386e-01, -6.0700e-02,
1.5806e-01, -1.4520e-01, 2.5764e-01, 2.9121e-01, -1.1582e-01,
1.5733e-01]],
[[-1.0970e+00, 4.6927e-01, -1.4577e+00, 1.5908e-01, -5.8003e-01,
4.2184e-01, -2.9130e-01, -4.3844e-01, -6.4602e-01, -3.3779e-01,
-5.0662e-01, 1.4041e+00, 3.3257e-01, -3.6793e-01, 8.6313e-02,
9.5632e-01],
[-6.1147e-01, 2.9910e-01, -4.6657e-01, 3.5230e-01, -2.9529e-01,
2.5505e-01, -1.3177e-01, -2.9946e-01, -8.0582e-03, -1.1524e-01,
-3.5823e-01, 2.2693e-01, 5.6150e-02, -4.6380e-01, -1.3922e-01,
6.2596e-01],
[ 2.3433e-02, 3.8697e-01, -6.9387e-01, 1.8022e-01, -4.8061e-01,
4.4059e-02, -1.0833e-02, -4.3189e-02, -1.9537e-01, -1.6067e-01,
-6.8359e-02, 5.4580e-01, -2.5206e-02, 6.9131e-01, 1.2385e-01,
2.1368e-01],
[-1.0784e+00, 4.6250e-01, -1.4431e+00, 1.5527e-01, -5.7458e-01,
4.1337e-01, -2.8499e-01, -4.3156e-01, -6.3295e-01, -3.3172e-01,
-4.9603e-01, 1.3840e+00, 3.2465e-01, -3.5593e-01, 8.8418e-02,
9.4286e-01],
[ 4.6625e-02, 6.0098e-02, 8.2663e-01, 5.8893e-01, 9.1147e-02,
3.1998e-02, 9.2721e-02, -1.0999e-01, 8.1375e-01, 1.9270e-01,
-1.4900e-01, -1.3173e+00, -3.1815e-01, -5.5063e-01, -3.9011e-01,
1.6641e-01],
[ 1.6852e-02, -4.9043e-02, 6.0541e-01, 4.3124e-01, 1.2890e-01,
4.4272e-02, 1.3500e-01, -5.4134e-02, 7.3558e-01, 2.5220e-01,
-7.2767e-02, -1.0761e+00, -2.9618e-01, -4.5633e-01, -1.8570e-01,
1.3738e-01],
[-5.4510e-01, 2.9503e-01, -9.1187e-01, 5.3948e-02, -3.7050e-01,
3.8831e-01, 1.8516e-02, -3.9320e-01, -3.5511e-01, 1.5370e-02,
-1.9676e-01, 7.7186e-01, 2.9472e-01, -3.1925e-01, 2.0224e-01,
6.0923e-01],
[ 1.3970e-01, 5.1662e-03, 6.2362e-01, 3.7341e-01, 1.1032e-01,
6.9541e-02, 1.7953e-01, -6.9769e-02, 7.3536e-01, 2.8889e-01,
1.7516e-02, -1.1217e+00, -2.4648e-01, -4.7009e-01, -2.5505e-01,
1.5395e-01]],
[[ 3.1063e-01, 9.3706e-01, -2.0925e-01, 7.7045e-01, 1.9348e-02,
5.7261e-01, 2.4843e-01, -9.7810e-01, 6.5967e-01, 5.0115e-01,
7.5831e-01, -2.6434e-02, 7.9072e-01, 1.9970e-01, -4.5862e-01,
-6.5785e-01],
[ 5.1229e-02, 4.2860e-01, 5.9542e-01, -5.5449e-01, 3.1818e-02,
-2.5724e-01, -4.5078e-01, -1.1213e+00, -4.9919e-01, 4.8322e-01,
-5.3645e-02, -4.6865e-02, 3.4156e-01, -1.5516e+00, -8.6092e-01,
-8.5549e-01],
[-2.9662e-01, 1.5494e-01, 5.1421e-01, -5.1660e-01, -1.3153e-01,
8.9164e-02, -6.5588e-04, -5.4235e-01, 4.3208e-02, 1.4455e-01,
2.8422e-01, -3.1536e-01, -1.1797e-01, -8.5984e-01, -1.5641e-01,
-2.6020e-01],
[-2.8156e-01, 1.4094e-01, 4.1021e-01, -4.6351e-01, -1.8295e-01,
8.4020e-02, 4.5066e-02, -4.9289e-01, 6.7285e-02, 9.2170e-02,
3.3895e-01, -3.0220e-01, -6.9244e-02, -6.9405e-01, -5.3239e-02,
-1.1269e-01],
[ 2.1831e-01, 5.4036e-01, -1.7186e-01, 1.0750e-01, -1.1477e-01,
4.3744e-03, 5.4357e-02, -7.4557e-01, -6.2228e-02, 2.0526e-01,
5.7527e-01, 1.5657e-02, 4.5107e-01, -1.2896e-01, -1.1162e-01,
-1.7100e-01],
[-2.4276e-01, -1.0055e+00, -6.2789e-01, 3.8000e-01, -2.9882e-01,
9.1979e-01, -1.1335e-01, 7.1017e-01, 1.5012e+00, 5.8242e-01,
-3.8730e-01, 1.4929e-01, 5.3973e-01, -1.4386e-01, 5.3706e-01,
5.4326e-01],
[ 1.1782e-01, 1.5628e-01, -1.1536e-01, 1.0920e-01, -9.9444e-02,
2.8018e-01, -3.0325e-01, -2.8041e-02, -1.1854e-01, -9.9304e-02,
3.9456e-01, 5.7266e-02, 1.6387e-01, 2.1351e-01, 1.7117e-01,
-5.8975e-01],
[ 1.1331e-01, 2.5244e-01, -1.1388e-01, -2.0675e-01, 1.8021e-01,
-4.2304e-03, 5.6730e-02, -1.6245e-02, -4.7757e-01, -1.7682e-01,
7.5786e-01, -1.6106e-03, -7.2714e-01, 1.6426e-01, 5.7538e-01,
-3.8376e-01]],
[[ 7.7075e-02, -1.5568e-01, -9.2207e-01, -3.8491e-01, 4.3124e-01,
7.5718e-01, 2.7922e-01, 5.6814e-01, 1.3448e+00, 3.2643e-01,
1.4705e-01, 4.3255e-01, -1.2313e+00, 6.3330e-01, -2.5603e-01,
-1.0371e+00],
[-2.5095e-01, -1.4535e-01, -1.2667e+00, 5.6085e-02, 6.1009e-01,
2.3402e-01, 4.3293e-01, 4.9392e-01, 1.0604e+00, 1.2088e-01,
1.1250e-01, 1.8252e-01, -7.5366e-01, 7.1970e-01, 3.9696e-01,
-1.5175e-01],
[ 1.8618e-02, -9.2575e-02, -1.7076e-01, -5.6385e-02, 6.6891e-01,
2.2329e-01, 2.9041e-01, -1.0262e-01, 1.4965e-01, 4.7727e-01,
4.1259e-01, -7.6315e-01, -2.8469e-01, 3.7417e-02, 7.4762e-01,
2.6716e-01],
[-5.2196e-01, -5.7511e-02, -8.8683e-01, -3.4342e-01, 4.4727e-01,
3.3393e-01, 4.7508e-01, 1.9206e-01, 2.9150e-01, 8.9754e-02,
-1.3457e-01, 5.0191e-01, -8.5755e-01, 4.6433e-01, 1.7871e-01,
-2.9310e-01],
[-2.3181e-01, -6.6739e-02, -7.3817e-01, -1.4173e-01, 2.2250e-01,
4.7868e-01, 1.9876e-01, 4.0749e-01, 8.6140e-01, -3.2872e-03,
1.0859e-01, 1.2460e-01, -7.7254e-01, 4.9392e-01, -5.3462e-02,
-3.6565e-01],
[-5.6486e-01, -2.9277e-03, -6.4147e-01, -3.0628e-01, 2.0365e-01,
3.8468e-01, 4.0435e-01, 1.4225e-01, 8.8813e-02, -2.1360e-02,
-1.6166e-01, 2.8291e-01, -6.1666e-01, 2.4582e-01, 1.7229e-01,
-7.8285e-02],
[-2.5574e-02, -1.2424e-01, -8.6561e-01, -3.4857e-01, 3.6858e-01,
6.9903e-01, 3.2227e-01, 4.9475e-01, 1.0981e+00, 2.7116e-01,
8.3502e-02, 3.5014e-01, -1.0540e+00, 5.1113e-01, -1.2425e-01,
-7.9824e-01],
[-2.5083e-01, -1.3141e-02, -4.9340e-01, 1.3992e-01, 2.8353e-01,
1.4285e-01, 2.8776e-01, 1.9136e-01, 3.7297e-01, 4.7570e-02,
1.1451e-01, -2.9702e-01, -2.5862e-01, 3.4686e-01, 3.7072e-01,
3.3875e-01]]], grad_fn=<UnsafeViewBackward0>)
# understand the final output
print('understanding the weighted sum: out = wei @ v')
print()
print(f'wei shape: {wei.shape} = (B={B}, T={T}, T={T})')
print(f'v shape: {v.shape} = (B={B}, T={T}, head_size={head_size})')
print(f'out shape: {out.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does out[b, i] represent?')
print(' it\'s a weighted average of all value vectors')
print(' weighted by how much position i attends to each position')
print()
print('for position 0 (can only see itself)')
print(f' wei[0, 0] = {[round(v, 3) for v in wei[0, 0].tolist()]}')
print(' out[0, 0] = 1.0 * v[0, 0] + 0.0 * v[0, 1] + ... + 0.0 * v[0, 7]')
print(' out[0, 0] ≈ v[0, 0] (just itself)')
print()
print('for position 3 (can see positions 0, 1, 2, 3)')
weights_3 = [round(w, 3) for w in wei[0, 3].tolist()]
print(f' wei[0, 3] = {weights_3}')
print(f' out[0, 3] = {weights_3[0]} * v[0, 0] + {weights_3[1]} * v[0, 1] + {weights_3[2]} * v[0, 2] + {weights_3[3]} * v[0, 3]')
print(' out[0, 3] is a weighted mix of values from positions 0-3')Output:
understanding the weighted sum: out = wei @ v
wei shape: torch.Size([4, 8, 8]) = (B=4, T=8, T=8)
v shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
out shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
what does out[b, i] represent?
it's a weighted average of all value vectors
weighted by how much position i attends to each position
for position 0 (can only see itself)
wei[0, 0] = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
out[0, 0] = 1.0 * v[0, 0] + 0.0 * v[0, 1] + ... + 0.0 * v[0, 7]
out[0, 0] ≈ v[0, 0] (just itself)
for position 3 (can see positions 0, 1, 2, 3)
wei[0, 3] = [0.129, 0.338, 0.138, 0.396, 0.0, 0.0, 0.0, 0.0]
out[0, 3] = 0.129 * v[0, 0] + 0.338 * v[0, 1] + 0.138 * v[0, 2] + 0.396 * v[0, 3]
out[0, 3] is a weighted mix of values from positions 0-3
# manually verify one output calculation
print('manually verifying out[0, 0]')
print()
print('out[0, 0] should equal v[0, 0] since position 0 only sees itself')
print()
print(f'out[0, 0] (first 4 values): {out[0, 0].tolist()[:4]}')
print(f'v[0, 0] (first 4 values): {v[0, 0].tolist()[:4]}')
print()
print(f'are they equal? {torch.allclose(out[0, 0], v[0, 0])}')Output:
manually verifying out[0, 0]
out[0, 0] should equal v[0, 0] since position 0 only sees itself
out[0, 0] (first 4 values): [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]
v[0, 0] (first 4 values): [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]
are they equal? True
# manually verify position 1 calculation
print('manually verifying out[0, 1]')
print()
print('position 1 can see positions 0 and 1')
print(f'wei[0, 1] = {[round(w, 4) for w in wei[0, 1].tolist()]}')
print()
print('computing the weighted sum manually')
w0 = wei[0, 1, 0].item()
w1 = wei[0, 1, 1].item()
print(f' weight for position 0: {w0:.4f}')
print(f' weight for position 1: {w1:.4f}')
print()
manual_out = w0 * v[0, 0] + w1 * v[0, 1]
print(f'manual calculation:')
print(f' {w0:.4f} * v[0, 0] + {w1:.4f} * v[0, 1]')
print(f' = {manual_out.tolist()[:4]}... (first 4 values)')
print()
print(f'actual out[0, 1]:')
print(f' = {out[0, 1].tolist()[:4]}... (first 4 values)')
print()
print(f'are they equal? {torch.allclose(manual_out, out[0, 1])}')Output:
manually verifying out[0, 1]
position 1 can see positions 0 and 1
wei[0, 1] = [0.1905, 0.8095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
computing the weighted sum manually
weight for position 0: 0.1905
weight for position 1: 0.8095
manual calculation:
0.1905 * v[0, 0] + 0.8095 * v[0, 1]
= [0.4058188199996948, -0.09202173352241516, -0.7652695178985596, -0.5147035717964172]... (first 4 values)
actual out[0, 1]:
= [0.4058188199996948, -0.09202173352241516, -0.7652694582939148, -0.5147035717964172]... (first 4 values)
are they equal? True
Let's look at what the attention pattern looks like. Remember that these are now data-dependent weights, not uniform averages!
| Concept | Explanation |
|---|---|
| Communication Mechanism | Tokens can "talk" to each other and share information. |
| No Position Awareness | Attention is set-based - we need positional encodings. |
| Batch Independence | Sequences in a batch don't interact. |
| Decoder vs Encoder | Decoder uses causal mask (triangle); Encoder allows full attention. |
| Self vs Cross Attention | Self: Q, K, V from same source; Cross: Q from one source, K/V from another. |
print('Attention Weights for First Sequence (batch 0)')
print(' Rows = \'from\' position (query)')
print(' Cols = \'to\' position (key)')
print()
print(' ', end='')
for j in range(T):
print(f' pos{j} ', end='')
print()
print(' ' + '-' * 55)
for i in range(T):
print(f' pos{i} |', end='')
for j in range(T):
val = wei[0, i, j].item()
if val > 0.001:
print(f' {val:.3f}', end=' ')
else:
print(f' --- ', end='')
print()
print()
print('Observations')
print(' - Lower triangular (can\'t attend to future).')
print(' - Each row sums to 1 (valid probability distribution).')
print(' - Values are NOT uniform - they\'re learned!')
print(' - Different positions attend differently to past tokens.')Output:
Attention Weights for First Sequence (batch 0)
Rows = 'from' position (query)
Cols = 'to' position (key)
pos0 pos1 pos2 pos3 pos4 pos5 pos6 pos7
-------------------------------------------------------
pos0 | 1.000 --- --- --- --- --- --- ---
pos1 | 0.191 0.809 --- --- --- --- --- ---
pos2 | 0.374 0.057 0.569 --- --- --- --- ---
pos3 | 0.129 0.338 0.138 0.396 --- --- --- ---
pos4 | 0.431 0.084 0.058 0.305 0.122 --- --- ---
pos5 | 0.054 0.321 0.069 0.240 0.257 0.059 --- ---
pos6 | 0.340 0.015 0.517 0.018 0.066 0.008 0.037 ---
pos7 | 0.016 0.037 0.014 0.112 0.033 0.407 0.314 0.066
Observations
- Lower triangular (can't attend to future).
- Each row sums to 1 (valid probability distribution).
- Values are NOT uniform - they're learned!
- Different positions attend differently to past tokens.
There's one more trick: we scale the attention scores by
When we compute Q @ K^T, the dot product grows larger as head_size increases.
For vectors with unit variance.
- If head_size = 16, the dot product has variance ≈ 16
- If head_size = 64, the dot product has variance ≈ 64
- If head_size = 512, the dot product has variance ≈ 512
Large values going into softmax cause problems.
- Softmax becomes "peaky" (one position gets nearly all attention)
- Gradients become very small (vanishing gradient problem)
- Training becomes unstable
By dividing by sqrt(head_size), we bring the variance back to ~1.
- variance of Q @ K^T ≈ head_size
- variance of (Q @ K^T) / sqrt(head_size) ≈ 1
This is why the formula in "Attention is All You Need" is.
# create random keys with unit variance
k = torch.randn(B, T, head_size)
kOutput:
tensor([[[-0.6744, -1.8893, -1.8424, 0.1323, -0.7929, 1.2297, 0.0777,
1.8036, -0.3388, -0.4670, -0.4019, -1.3110, 0.0308, -0.5922,
-1.1771, 1.7409],
[-0.2961, -0.3474, -0.4967, -1.3010, 1.3099, -0.2666, 0.1970,
-0.6992, 1.1396, 0.1912, -0.0095, 0.3546, -0.4238, 1.0712,
2.7125, -0.1935],
[ 1.7503, -0.1117, -0.8220, 0.7975, -0.7685, 1.5376, -1.7771,
-1.0646, 1.0508, 1.3841, -1.5027, -1.0865, 2.1496, -0.9262,
-0.8618, -0.0133],
[ 0.9761, -0.0773, -2.1688, 1.2137, -1.8086, 0.1943, 0.6680,
-1.1589, -0.7162, -1.0271, -1.4785, 0.0458, -0.1069, 0.3531,
0.3302, -0.5309],
[ 0.0363, 2.4673, -0.1655, -0.3069, 1.4189, -0.4566, -1.5976,
0.7736, -0.6360, -0.2510, 0.7005, 1.4388, -1.0685, -0.1663,
0.5176, -0.7325],
[ 0.3359, -0.7604, 0.0566, -1.5039, -0.4485, 0.5257, 0.2619,
0.7167, -0.6965, 0.8436, 1.9249, -0.3405, -0.4329, 1.3084,
0.4293, 0.0712],
[-1.4018, 0.5611, 1.1513, 0.6989, -0.5898, -0.1646, -0.4931,
0.5041, 0.1377, 0.2751, 0.4683, -0.7030, -0.1796, 0.8974,
0.0517, -0.5315],
[ 0.4069, 0.4082, -0.4961, -0.9291, -0.1993, 0.4683, 1.0864,
-0.4892, -0.0861, 0.6074, 0.2278, -0.6186, 1.1309, -0.1208,
1.6047, 0.0861]],
[[ 0.2813, 0.0870, -0.2571, 2.2180, 1.2402, -0.6573, 1.8484,
-1.1966, -0.4539, 1.4244, 2.2692, 1.3105, -0.3179, -0.3774,
2.2604, -0.3310],
[-0.7194, 1.2199, 1.4356, -0.3140, 0.8979, 0.6359, -0.8476,
-0.0931, -0.3936, -0.0248, -0.3633, -0.6941, -0.9816, -0.0556,
-1.0469, -0.1615],
[-0.5718, -1.5561, -0.9633, -0.4366, -0.0085, 0.0446, -0.3537,
0.1575, -1.1567, 1.8156, -2.0921, -0.6517, 1.1426, -0.7538,
-1.4663, 0.0802],
[-0.6317, -0.7410, 1.8064, 0.9378, -0.3845, 0.6585, 0.7617,
-0.6451, -3.6308, -2.1864, 0.2644, -0.5599, 1.4537, -0.2963,
-0.4702, -1.4991],
[ 2.2968, 1.6495, 1.3179, 0.7556, 1.2472, 0.7881, 1.5493,
-0.6089, -2.7026, -0.6109, 1.1898, -0.4802, 2.2536, 1.1718,
0.8793, -0.7797],
[-0.0781, -0.3723, 0.3638, 1.2563, -0.1221, 0.1012, 0.4712,
0.6840, 0.5099, -0.7802, 0.6629, 0.6557, 0.0585, 0.7882,
-1.0858, 1.0520],
[-0.3893, 1.4754, -0.1709, -2.0884, 0.7963, 0.4962, 0.6029,
-0.5226, 1.0361, 0.5318, -0.3148, 0.0210, -0.0545, -0.8116,
-0.2611, -0.6926],
[ 1.5523, -2.3087, -2.1958, 0.3203, 0.7727, -0.1667, -0.0118,
-0.1128, -0.6838, -1.2514, -0.0760, 0.3789, 0.6201, -0.0899,
1.2097, 0.8767]],
[[ 1.8313, -0.6159, -0.6073, -2.0597, 1.5289, 0.3379, 0.1915,
0.1635, 0.6710, -0.4096, -0.5302, 0.2533, -0.1990, 0.6101,
-1.4391, 1.6621],
[ 0.3556, -1.8120, 0.4646, -0.5480, -1.0596, 0.1740, 0.3822,
-0.1958, -0.1513, 0.6256, -0.6219, -1.0873, -1.3252, 0.3772,
-0.0584, -1.4766],
[-0.9860, 1.4866, 0.1471, -1.3660, -0.6709, 0.9521, 1.4749,
-1.4756, -0.8660, 1.2781, 0.3526, -0.0750, 0.4059, 0.5351,
-0.0688, -0.6155],
[ 0.2696, -0.0316, -1.2757, -0.6373, -0.7616, -0.4670, -1.2028,
-2.4588, -0.4899, -1.5937, 0.9481, -0.4265, -1.4827, -0.4504,
0.8890, -1.1526],
[ 0.0295, -0.5199, -0.1654, -0.2773, -0.2447, -1.9880, -1.2664,
-0.3072, 0.8398, -0.4689, 0.2266, 0.3419, 0.5934, 1.9173,
-0.4787, -0.0578],
[-1.7239, -0.9909, 1.9552, -0.0653, 0.1463, 1.1357, -0.2689,
-0.9127, 0.6866, 1.5644, 1.0132, -1.1486, -0.7916, -0.3214,
0.5456, -1.2671],
[ 0.5780, -0.0210, -0.1380, 0.0994, -0.1628, 0.1898, -1.2572,
0.2571, -1.0626, -0.6326, -0.6293, -1.6768, 0.6724, 1.9889,
0.8157, -1.4683],
[ 1.6630, -1.4545, -0.2315, 0.5550, 0.3245, 1.4937, 0.5853,
0.7600, -1.0136, -1.3920, 0.8857, 0.9162, 0.4851, -1.0356,
0.1621, -0.3456]],
[[ 0.7719, 0.0167, 0.6804, -0.1298, 0.0973, 0.7957, -2.1607,
-0.5694, -2.0023, -1.2304, 0.8770, -2.0921, 1.5937, 2.5637,
-0.1268, 0.2314],
[ 0.7924, -0.3076, 0.6760, 2.6806, -0.8708, 0.0361, 1.0990,
-0.2800, 0.5311, 0.5320, -1.5853, 2.4220, 0.4772, 0.5957,
0.2793, 0.2393],
[ 0.4738, 0.0311, -0.1489, -0.3652, -1.8156, 1.1129, 1.1716,
-1.7179, 1.0240, -1.0366, -1.9978, 1.5088, 0.1965, 1.0685,
0.4851, 0.0060],
[ 1.0007, 0.7049, -0.6978, 0.4729, -0.6567, -0.8678, -0.1043,
0.9756, -0.8829, -0.7063, -1.2800, 0.1359, -0.2811, 1.7253,
0.1270, -0.8810],
[-0.6381, 0.5337, 0.1681, -1.0806, 0.8653, 0.9823, 0.7240,
0.1330, -0.6278, -0.1459, -0.4236, -1.4882, 0.8582, 3.0351,
-1.1488, 0.2271],
[ 0.0306, 0.0151, 1.1773, -0.9650, -0.2467, -0.6798, -1.0098,
-0.3883, -1.3796, 1.0700, -0.9035, 0.7684, 0.4392, -0.5033,
2.1168, 1.2190],
[-0.7853, 1.0901, -0.0665, 1.2573, 0.1582, -1.7430, -1.2939,
1.3075, 0.7086, 0.2949, -0.6938, -0.8013, -0.0776, -0.5015,
-2.2270, -0.1726],
[-0.6626, -0.5495, 0.0587, 1.5382, 1.0445, -0.2630, 0.2191,
0.0512, 1.1272, 0.5445, -0.2186, 0.4121, -1.1325, -2.3891,
0.7178, -1.5831]]])
# Create random queries with unit variance
q = torch.randn(B, T, head_size)
qOutput:
tensor([[[-9.6348e-01, -1.0543e+00, -6.1099e-01, 1.1033e-01, 1.2356e-01,
-1.4389e+00, -4.5936e-01, 7.1935e-01, -9.6226e-02, -6.8070e-01,
7.3392e-01, 9.3939e-02, 1.0835e+00, 8.0898e-01, -9.7732e-01,
-2.6084e-01],
[ 9.0191e-01, 3.1770e-01, 1.5054e+00, -4.5409e-04, -8.3999e-01,
-9.9635e-01, 1.9696e+00, -6.2411e-01, 7.8123e-01, -1.4737e+00,
9.1280e-01, -8.1394e-01, -3.2805e-01, -1.6034e+00, 1.5658e-01,
1.2400e+00],
[-1.3389e+00, -1.0444e-01, 1.5695e-01, -1.5132e+00, 9.9128e-01,
5.5732e-01, -6.7796e-01, 9.6848e-01, 8.3635e-01, -2.0765e+00,
9.2636e-01, 1.8823e+00, 2.7995e-02, -3.6298e-01, 4.5504e-01,
7.5949e-01],
[-9.6253e-01, 9.5393e-01, -1.4123e+00, 8.1285e-01, 1.4346e+00,
5.7747e-02, -8.9515e-01, -8.5902e-02, -6.0462e-01, -6.8750e-01,
2.0560e-01, -7.1922e-01, -1.1453e+00, 8.8890e-01, 2.4767e-01,
9.7610e-01],
[-1.0026e+00, -8.6914e-01, 1.0349e+00, 1.1414e+00, -6.1135e-01,
5.6699e-01, -1.5298e-01, -2.9166e-01, -1.2069e+00, -1.6842e-01,
-1.0213e+00, 4.5474e-01, 5.6282e-02, 1.9085e-01, -2.8136e-03,
-6.4238e-01],
[-2.3483e-01, 1.8348e-01, 8.2710e-01, 6.8175e-01, 4.0631e-01,
1.7062e+00, 1.1659e+00, -2.4008e-01, 2.2485e-01, -2.3762e+00,
4.0155e-01, -2.2946e+00, 9.5440e-01, -3.8834e-01, 2.1960e+00,
8.4125e-01],
[-1.4890e+00, 5.8501e-01, -6.4059e-01, -1.9064e+00, -2.1498e-01,
1.6726e-01, 8.5943e-02, -3.8008e-01, -1.3825e+00, 5.6673e-01,
-2.2063e+00, 2.8584e-01, 2.4994e+00, 5.4578e-02, -1.1838e+00,
8.2043e-01],
[ 7.9914e-01, 3.4314e-01, -7.1088e-01, 4.0654e-01, 9.5622e-01,
3.0749e-01, 3.1811e-01, -1.8298e+00, 1.8508e+00, -1.2886e+00,
1.2673e+00, -9.6881e-01, -4.6094e-01, 8.4074e-01, -1.9394e-01,
-1.4038e-01]],
[[ 7.0630e-02, -6.8063e-02, 1.2693e+00, 2.2910e+00, -7.9681e-02,
-2.1427e+00, 1.5941e+00, 2.5384e+00, -4.9620e-01, 1.5591e+00,
-1.2024e+00, -6.3514e-01, 2.0571e-01, 1.2085e+00, -1.2975e+00,
-1.7842e+00],
[-1.7387e+00, 1.7332e+00, 7.3354e-01, 9.9386e-01, 8.0198e-03,
-2.8346e-01, -1.0711e+00, -3.9535e-01, 7.2411e-01, 1.1350e+00,
1.3726e-01, -2.8324e-01, 8.4498e-01, -1.6652e+00, 9.0134e-01,
2.9201e-01],
[ 9.8307e-01, -1.6503e+00, -8.4474e-01, -1.7778e+00, 1.6798e+00,
-3.0757e-01, 1.7818e-01, 3.0598e-01, 6.4698e-01, 1.3001e+00,
6.4310e-01, -1.3192e+00, -5.6939e-01, -1.8996e+00, 2.5208e-01,
7.2486e-01],
[-1.0662e-02, 3.2636e-01, -3.9913e-01, -5.4581e-01, -9.4928e-01,
-5.8264e-01, 6.0848e-01, 1.8587e+00, -2.0494e+00, 1.2204e+00,
1.3751e+00, 4.9697e-01, -5.0501e-01, 1.3103e+00, -8.3440e-02,
2.1579e-01],
[ 1.1008e-01, -4.9934e-01, 1.4105e+00, 9.0974e-01, 8.3605e-01,
-4.5467e-01, -1.4347e-01, 1.0890e-01, 6.9148e-01, 4.3387e-01,
1.6470e-01, 9.2992e-01, 4.3113e-01, -2.2406e+00, 3.4794e-01,
-1.4979e+00],
[ 2.8999e-01, -2.8325e-01, -7.4868e-01, -1.4254e+00, -3.2493e-01,
2.4911e-01, -1.4054e+00, 7.2349e-03, -6.6038e-01, -9.2917e-01,
1.0411e+00, 1.7303e+00, -8.9635e-01, -1.7487e+00, 2.3790e-01,
-1.8056e+00],
[-4.1451e-01, -1.7113e+00, 3.6452e-01, -5.8035e-01, -7.1043e-01,
-2.7302e-01, 8.7178e-01, 2.1582e-01, 5.2196e-01, 3.4126e-01,
1.0109e+00, -4.6918e-02, 2.9928e-01, -4.2063e-01, -1.0377e+00,
6.3906e-01],
[ 2.3253e-02, -1.0358e+00, 6.9172e-01, -6.7895e-01, -9.2474e-01,
-3.1068e-01, 1.9666e-01, 6.3062e-01, -1.7229e+00, -5.2109e-01,
6.0717e-02, 4.2114e-01, 1.0577e+00, 1.3254e+00, 1.1332e+00,
4.2492e-01]],
[[ 7.7359e-01, 1.4078e-01, -2.7783e-01, 1.5577e-02, -9.0613e-01,
-6.0330e-01, 5.2168e-02, -5.9930e-01, -3.0289e+00, 2.6210e-01,
2.5721e+00, -4.7537e-01, 3.6165e-01, -5.7933e-01, -7.2539e-01,
1.3805e+00],
[-1.5331e-01, -1.3848e+00, 4.5417e-01, 3.8795e-01, -6.6643e-01,
-5.4208e-01, 2.9302e+00, 3.9807e-02, -6.9273e-01, -6.2174e-01,
7.4563e-01, -4.6407e-01, 3.5406e-01, 5.7339e-01, -1.9254e+00,
1.3461e+00],
[ 1.2138e+00, -8.6570e-02, -7.5610e-01, -6.0872e-01, 6.5236e-02,
2.0737e-01, -4.8440e-01, -4.9315e-01, -6.0584e-01, -3.8430e-01,
-1.5004e+00, 1.4241e+00, 4.2643e-01, 1.7253e+00, -1.1768e+00,
-5.5720e-01],
[ 6.2589e-01, -2.1850e-02, -1.7424e+00, -7.6729e-01, 1.4304e+00,
3.9235e-01, -1.3970e+00, -2.0021e+00, 2.3755e-01, 2.4901e+00,
3.1266e-01, -3.8830e-01, 8.9288e-01, 7.2496e-01, 4.3352e-01,
9.7441e-01],
[-4.0826e-03, 7.2984e-02, -2.5910e-01, 7.9609e-01, 5.7920e-01,
-1.3451e-01, -1.3556e+00, -3.9413e-01, 1.3683e+00, 1.2856e+00,
-1.3335e+00, 9.4325e-01, -1.1380e+00, -1.1207e+00, -6.6216e-01,
1.8776e+00],
[-9.0479e-01, -1.0662e-02, -3.0602e-01, -2.0311e+00, 2.8344e-01,
3.8604e-01, -2.3576e-01, -1.2253e+00, -7.9281e-01, 4.7013e-01,
9.6294e-02, 1.7162e+00, -1.5964e-01, -1.7001e-01, -8.2118e-01,
2.3731e-02],
[ 5.2973e-01, -3.7780e-01, 1.4174e+00, 4.9107e-01, 9.1596e-01,
1.7351e-01, 1.8969e-01, -1.5841e-01, -1.9692e+00, -2.4453e-01,
-6.7767e-01, 4.7823e-01, -1.8906e+00, 2.9901e-01, 1.5097e+00,
-4.0141e-01],
[ 3.6162e-01, 4.4447e-01, 3.7175e-01, 1.7289e-01, -1.2142e+00,
1.9442e+00, 1.1369e+00, -9.7678e-01, 9.7180e-01, 1.5757e-01,
-1.2937e-01, 9.4592e-01, 7.4863e-01, 1.3490e+00, 1.6593e+00,
-6.8208e-01]],
[[ 8.5920e-01, 1.0688e+00, 1.0354e+00, -1.0889e+00, -9.8336e-01,
1.8595e-01, 1.5627e+00, -1.3852e+00, -4.5413e-01, 6.9793e-01,
5.4690e-01, 5.5246e-01, 2.4029e-01, -1.1295e+00, -2.9120e-01,
-1.4457e+00],
[ 3.2800e-01, -2.2357e-01, 6.3385e-01, 1.2713e+00, 5.2576e-01,
-8.8591e-02, 8.3571e-02, -5.5072e-01, -4.5323e-01, 1.0592e+00,
-1.7509e+00, 8.3122e-01, -2.0036e-01, -1.2242e+00, -1.6658e-01,
-6.3705e-01],
[-8.2356e-01, -3.6673e-01, 5.7609e-01, -5.4015e-01, 2.1437e+00,
-8.4680e-01, 1.0048e+00, 9.9498e-02, -9.2466e-01, 6.7150e-01,
4.1903e-01, -2.2128e-01, 5.4170e-01, 2.7081e-01, 6.5265e-01,
-2.2563e-01],
[-2.7009e-01, 1.1325e+00, 1.2827e+00, 4.0248e-01, 1.5574e+00,
2.0379e+00, -1.1328e-01, -1.0065e+00, -3.7983e-01, -6.7844e-01,
-1.0526e+00, 2.2322e-01, 1.0225e+00, 7.7577e-01, 6.4597e-01,
7.4031e-01],
[-7.4956e-01, -1.1346e+00, 4.3102e-01, -2.2315e-01, 4.0467e-02,
3.5534e-01, 1.9383e+00, -1.2549e+00, 4.5961e-01, 1.5001e+00,
3.8992e-01, 4.0374e-01, -4.7906e-01, -5.7640e-01, -2.3239e+00,
-2.8545e-01],
[-5.0370e-01, 5.8248e-01, -2.6750e+00, 1.8529e-01, -1.3125e+00,
-7.7555e-01, -9.4621e-02, -1.1716e+00, 5.1212e-01, 3.4507e-01,
-1.1712e+00, 2.5597e-01, 4.5153e-01, -7.7741e-01, -2.5793e+00,
1.3328e+00],
[ 3.1002e-01, 8.0718e-02, 1.3723e-01, 1.2757e+00, 4.2629e-01,
1.2246e-01, 5.5749e-01, 3.0464e-01, -5.0844e-01, -3.8419e-01,
-1.8623e-01, -1.6419e-01, 5.0378e-01, 1.2022e+00, 5.4761e-01,
-5.0970e-01],
[ 1.6420e+00, -1.3696e+00, -1.3332e-01, 6.1232e-01, -2.7727e+00,
4.3604e-01, 1.2461e-01, 1.5180e+00, 6.8525e-01, 1.6848e+00,
-5.5461e-01, -3.9066e-01, 1.0325e+00, 1.6486e-01, -5.9153e-01,
-1.4602e-02]]])
# understand the scaling problem
print('understanding the scaling problem')
print()
print(f'head_size = {head_size}')
print()
print('we created random keys and queries with unit variance (~1)')
print(f' key variance: {k.var().item():.4f} (should be ~1)')
print(f' query variance: {q.var().item():.4f} (should be ~1)')
print()
print('when we multiply unit-variance vectors, variance grows')
print(' for dot product of two vectors with unit variance')
print(' the resulting variance ≈ length of vectors = head_size')
print()
print(f' expected variance after Q @ K^T: ~{head_size}')Output:
understanding the scaling problem
head_size = 16
we created random keys and queries with unit variance (~1)
key variance: 1.0392 (should be ~1)
query variance: 0.9791 (should be ~1)
when we multiply unit-variance vectors, variance grows
for dot product of two vectors with unit variance
the resulting variance ≈ length of vectors = head_size
expected variance after Q @ K^T: ~16
# demonstrate the problem: unscaled attention
print('demonstrating the problem: unscaled attention')
print()
wei_unscaled = q @ k.transpose(-2, -1)
print(f'wei_unscaled = q @ k.transpose(-2, -1)')
print(f'wei_unscaled variance: {wei_unscaled.var().item():.4f}')
print()
print(f'expected variance: ~{head_size}')
print(f'actual variance: {wei_unscaled.var().item():.2f}')
print()
print('the variance grew from ~1 to ~{head_size}!')
print('this is because dot product of two unit-variance vectors')
print(f'has variance equal to the length of the vectors ({head_size})')Output:
demonstrating the problem: unscaled attention
wei_unscaled = q @ k.transpose(-2, -1)
wei_unscaled variance: 12.4662
expected variance: ~16
actual variance: 12.47
the variance grew from ~1 to ~{head_size}!
this is because dot product of two unit-variance vectors
has variance equal to the length of the vectors (16)
# the solution: scaled attention
print('the solution: scaled attention')
print()
scale_factor = head_size ** -0.5
print(f'scale factor = 1 / sqrt(head_size) = 1 / sqrt({head_size}) = {scale_factor:.4f}')
print()
wei_scaled = (q @ k.transpose(-2, -1)) * scale_factor
print(f'wei_scaled = (q @ k.transpose(-2, -1)) * {scale_factor:.4f}')
print()
print(f'wei_scaled variance: {wei_scaled.var().item():.4f}')
print()
print('the variance is back to ~1!')
print()
print('why does this work?')
print(f' unscaled variance ≈ {head_size}')
print(f' multiplying by {scale_factor:.4f} = 1/sqrt({head_size})')
print(f' divides variance by {head_size} (because Var(aX) = a²Var(X))')
print(f' {head_size} * ({scale_factor:.4f})² = {head_size * scale_factor**2:.4f} ≈ 1')Output:
the solution: scaled attention
scale factor = 1 / sqrt(head_size) = 1 / sqrt(16) = 0.2500
wei_scaled = (q @ k.transpose(-2, -1)) * 0.2500
wei_scaled variance: 0.7791
the variance is back to ~1!
why does this work?
unscaled variance ≈ 16
multiplying by 0.2500 = 1/sqrt(16)
divides variance by 16 (because Var(aX) = a²Var(X))
16 * (0.2500)² = 1.0000 ≈ 1
When values are too large, softmax becomes "peaky" - almost all probability goes to one position. When values are small, softmax is "diffuse" - probability spreads more evenly.
For learning to work well, we need.
- Gradients to flow to multiple positions
- The network to learn which positions to attend to
- Smooth, not sharp, attention distributions early in training
# create example with small values
print('example: small values going into softmax')
print()
small_values = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(f'small_values = {small_values.tolist()}')
print()
print('these values have variance close to the expected ~1 from scaling')
print(f' variance: {small_values.var().item():.4f}')
print(f' max absolute value: {small_values.abs().max().item():.4f}')Output:
example: small values going into softmax
small_values = [0.10000000149011612, -0.20000000298023224, 0.30000001192092896, -0.20000000298023224, 0.5]
these values have variance close to the expected ~1 from scaling
variance: 0.0950
max absolute value: 0.5000
# apply softmax to small values
print('applying softmax to small values')
print()
softmax_small = torch.softmax(small_values, dim=-1)
print(f'input: {small_values.tolist()}')
print(f'softmax: {[round(x, 4) for x in softmax_small.tolist()]}')
print()
print('tracing through the softmax calculation')
import math
exp_vals = [math.exp(v.item()) for v in small_values]
exp_sum = sum(exp_vals)
print(f' e^0.1 = {exp_vals[0]:.4f}')
print(f' e^-0.2 = {exp_vals[1]:.4f}')
print(f' e^0.3 = {exp_vals[2]:.4f}')
print(f' e^-0.2 = {exp_vals[3]:.4f}')
print(f' e^0.5 = {exp_vals[4]:.4f}')
print(f' sum = {exp_sum:.4f}')
print()
print('softmax = e^x / sum(e^x)')
for i, (exp_v, soft_v) in enumerate(zip(exp_vals, softmax_small.tolist())):
print(f' position {i}: {exp_v:.4f} / {exp_sum:.4f} = {soft_v:.4f}')Output:
applying softmax to small values
input: [0.10000000149011612, -0.20000000298023224, 0.30000001192092896, -0.20000000298023224, 0.5]
softmax: [0.1925, 0.1426, 0.2351, 0.1426, 0.2872]
tracing through the softmax calculation
e^0.1 = 1.1052
e^-0.2 = 0.8187
e^0.3 = 1.3499
e^-0.2 = 0.8187
e^0.5 = 1.6487
sum = 5.7412
softmax = e^x / sum(e^x)
position 0: 1.1052 / 5.7412 = 0.1925
position 1: 0.8187 / 5.7412 = 0.1426
position 2: 1.3499 / 5.7412 = 0.2351
position 3: 0.8187 / 5.7412 = 0.1426
position 4: 1.6487 / 5.7412 = 0.2872
# analyze the small values softmax result
print('analyzing small values softmax result')
print()
print(f'softmax: {[round(x, 4) for x in softmax_small.tolist()]}')
print()
print('observations')
print(f' min probability: {softmax_small.min().item():.4f}')
print(f' max probability: {softmax_small.max().item():.4f}')
print(f' ratio max/min: {softmax_small.max().item() / softmax_small.min().item():.2f}x')
print()
print('the distribution is DIFFUSE (spread out)')
print(' - all positions get meaningful probability')
print(' - no single position dominates')
print(' - gradients can flow to all positions')
print(' - this is GOOD for learning!')Output:
analyzing small values softmax result
softmax: [0.1925, 0.1426, 0.2351, 0.1426, 0.2872]
observations
min probability: 0.1426
max probability: 0.2872
ratio max/min: 2.01x
the distribution is DIFFUSE (spread out)
- all positions get meaningful probability
- no single position dominates
- gradients can flow to all positions
- this is GOOD for learning!
# create example with large values
print('example: large values going into softmax')
print()
large_values = small_values * 8
print(f'large_values = small_values * 8')
print(f' = {large_values.tolist()}')
print()
print('this simulates what happens WITHOUT scaling')
print('if head_size = 64, values could be 8x larger')
print(f' variance: {large_values.var().item():.4f}')
print(f' max absolute value: {large_values.abs().max().item():.4f}')Output:
example: large values going into softmax
large_values = small_values * 8
= [0.800000011920929, -1.600000023841858, 2.4000000953674316, -1.600000023841858, 4.0]
this simulates what happens WITHOUT scaling
if head_size = 64, values could be 8x larger
variance: 6.0800
max absolute value: 4.0000
# apply softmax to large values
print('applying softmax to large values')
print()
softmax_large = torch.softmax(large_values, dim=-1)
print(f'input: {large_values.tolist()}')
print(f'softmax: {[round(x, 4) for x in softmax_large.tolist()]}')
print()
print('tracing through the softmax calculation')
exp_vals_large = [math.exp(v.item()) for v in large_values]
exp_sum_large = sum(exp_vals_large)
print(f' e^0.8 = {exp_vals_large[0]:.4f}')
print(f' e^-1.6 = {exp_vals_large[1]:.4f}')
print(f' e^2.4 = {exp_vals_large[2]:.4f}')
print(f' e^-1.6 = {exp_vals_large[3]:.4f}')
print(f' e^4.0 = {exp_vals_large[4]:.4f}')
print(f' sum = {exp_sum_large:.4f}')
print()
print('notice how e^4.0 = {:.2f} dominates the sum!'.format(exp_vals_large[4]))Output:
applying softmax to large values
input: [0.800000011920929, -1.600000023841858, 2.4000000953674316, -1.600000023841858, 4.0]
softmax: [0.0326, 0.003, 0.1615, 0.003, 0.8]
tracing through the softmax calculation
e^0.8 = 2.2255
e^-1.6 = 0.2019
e^2.4 = 11.0232
e^-1.6 = 0.2019
e^4.0 = 54.5982
sum = 68.2507
notice how e^4.0 = 54.60 dominates the sum!
# analyze the large values softmax result
print('analyzing large values softmax result')
print()
print(f'softmax: {[round(x, 4) for x in softmax_large.tolist()]}')
print()
print('observations')
print(f' min probability: {softmax_large.min().item():.6f}')
print(f' max probability: {softmax_large.max().item():.4f}')
print(f' ratio max/min: {softmax_large.max().item() / softmax_large.min().item():.0f}x')
print()
print('the distribution is PEAKY (concentrated)')
print(' - position 4 gets almost all probability (~0.96)')
print(' - other positions get almost nothing')
print(' - gradients only flow to position 4')
print(' - this is BAD for learning!')
print()
print('this is why we scale by 1/sqrt(head_size)')
print(' - keeps attention scores in a reasonable range')
print(' - prevents softmax from becoming too peaky')
print(' - allows gradients to flow during training')Output:
analyzing large values softmax result
softmax: [0.0326, 0.003, 0.1615, 0.003, 0.8]
observations
min probability: 0.002958
max probability: 0.8000
ratio max/min: 270x
the distribution is PEAKY (concentrated)
- position 4 gets almost all probability (~0.96)
- other positions get almost nothing
- gradients only flow to position 4
- this is BAD for learning!
this is why we scale by 1/sqrt(head_size)
- keeps attention scores in a reasonable range
- prevents softmax from becoming too peaky
- allows gradients to flow during training
Layer Normalization normalizes the features of each sample independently, ensuring they have mean=0 and std=1.
For each sample x with features [x₁, x₂, ..., xₙ].
- Compute mean: μ = (x₁ + x₂ + ... + xₙ) / n
- Compute variance: σ² = Σ(xᵢ - μ)² / n
- Normalize: x̂ᵢ = (xᵢ - μ) / √(σ² + ε)
- Scale and shift: yᵢ = γ * x̂ᵢ + β
Where.
- ε is a small constant for numerical stability (avoid divide by zero)
- γ (gamma) and β (beta) are learnable parameters
Without normalization, activations can.
- Explode (grow very large) → gradients explode
- Vanish (become very small) → gradients vanish
- Have very different scales → training instability
Layer normalization keeps activations in a stable range, making training faster and more reliable.
| Normalization | Normalizes across | Used in |
|---|---|---|
| BatchNorm | Same feature across batch samples (column-wise) | CNNs |
| LayerNorm | All features within one sample (row-wise) | Transformers |
LayerNorm is preferred in Transformers because.
- Works with any batch size (including batch_size=1 for inference)
- Each token's representation is normalized independently
- No dependency on other samples in the batch
class LayerNorm1d:
"""
Layer Normalization: normalizes each sample independently.
For input of shape (batch_size, features):
- Each row (sample) is normalized to have mean=0, std=1
- Then scaled by learnable gamma and shifted by learnable beta
"""
def __init__(self, dim, eps=1e-5):
"""
Initialize the LayerNorm1d layer.
Args:
dim: Number of features.
eps: Small constant for numerical stability (avoid divide by zero).
"""
self.eps = eps
# learnable parameters
self.gamma = torch.ones(dim) # scale (initialized to 1)
self.beta = torch.zeros(dim) # shift (initialized to 0)
def __call__(self, x):
"""
Forward pass.
Args:
x: Input tensor of shape (batch_size, dim).
Returns:
Normalized tensor of same shape.
"""
# step 1: compute mean of each sample (across features)
xmean = x.mean(dim=1, keepdim=True) # (batch, 1)
# step 2: compute variance of each sample
xvar = x.var(dim=1, keepdim=True) # (batch, 1)
# step 3: normalize to mean=0, var=1
xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
# step 4: scale and shift (learnable)
self.out = self.gamma * xhat + self.beta
return self.out
def parameters(self):
"""
Return learnable parameters.
Returns:
List containing gamma (scale) and beta (shift) tensors.
"""
return [self.gamma, self.beta]# understanding the LayerNorm class step by step
print('understanding the LayerNorm1d class')
print()
print('__init__(self, dim, eps=1e-5)')
print(' dim: number of features to normalize over')
print(' eps: small constant to avoid division by zero')
print(' gamma: learnable scale parameter, initialized to 1')
print(' beta: learnable shift parameter, initialized to 0')
print()
print('__call__(self, x)')
print(' step 1: compute mean across features (dim=1)')
print(' xmean = x.mean(dim=1, keepdim=True)')
print()
print(' step 2: compute variance across features (dim=1)')
print(' xvar = x.var(dim=1, keepdim=True)')
print()
print(' step 3: normalize to mean=0, variance=1')
print(' xhat = (x - xmean) / sqrt(xvar + eps)')
print()
print(' step 4: scale and shift (learnable)')
print(' out = gamma * xhat + beta')
print()
print('parameters()')
print(' returns [gamma, beta] for optimization')Output:
understanding the LayerNorm1d class
__init__(self, dim, eps=1e-5)
dim: number of features to normalize over
eps: small constant to avoid division by zero
gamma: learnable scale parameter, initialized to 1
beta: learnable shift parameter, initialized to 0
__call__(self, x)
step 1: compute mean across features (dim=1)
xmean = x.mean(dim=1, keepdim=True)
step 2: compute variance across features (dim=1)
xvar = x.var(dim=1, keepdim=True)
step 3: normalize to mean=0, variance=1
xhat = (x - xmean) / sqrt(xvar + eps)
step 4: scale and shift (learnable)
out = gamma * xhat + beta
parameters()
returns [gamma, beta] for optimization
# test layer normalization
torch.manual_seed(42)
# create the layer
n_features = 100
module = LayerNorm1d(n_features)
print(f'created LayerNorm1d with {n_features} features')
print(f' gamma shape: {module.gamma.shape}')
print(f' beta shape: {module.beta.shape}')
print()
# create test data
n_samples = 32
x = torch.randn(n_samples, n_features)
print(f'created test data with shape: {x.shape}')
print(f' {n_samples} samples, {n_features} features each')Output:
created LayerNorm1d with 100 features
gamma shape: torch.Size([100])
beta shape: torch.Size([100])
created test data with shape: torch.Size([32, 100])
32 samples, 100 features each
# examine the data BEFORE normalization
print('examining data BEFORE normalization')
print()
print('sample 0')
print(f' mean: {x[0].mean().item():.4f}')
print(f' std: {x[0].std().item():.4f}')
print(f' min: {x[0].min().item():.4f}')
print(f' max: {x[0].max().item():.4f}')
print()
print('sample 1')
print(f' mean: {x[1].mean().item():.4f}')
print(f' std: {x[1].std().item():.4f}')
print(f' min: {x[1].min().item():.4f}')
print(f' max: {x[1].max().item():.4f}')
print()
print('sample 2')
print(f' mean: {x[2].mean().item():.4f}')
print(f' std: {x[2].std().item():.4f}')
print(f' min: {x[2].min().item():.4f}')
print(f' max: {x[2].max().item():.4f}')
print()
print('note: means are not 0, stds are not exactly 1')
print('each sample has slightly different statistics')Output:
examining data BEFORE normalization
sample 0
mean: 0.0320
std: 1.0406
min: -2.5095
max: 2.2181
sample 1
mean: 0.0811
std: 0.9597
min: -2.4801
max: 2.0265
sample 2
mean: 0.0571
std: 0.9930
min: -2.5850
max: 2.1120
note: means are not 0, stds are not exactly 1
each sample has slightly different statistics
# apply layer normalization
x_normed = module(x)
print('applied layer normalization')
print(f' x_normed shape: {x_normed.shape}')Output:
applied layer normalization
x_normed shape: torch.Size([32, 100])
# examine the data AFTER normalization
print('examining data AFTER normalization')
print()
print('sample 0')
print(f' mean: {x_normed[0].mean().item():.6f}')
print(f' std: {x_normed[0].std().item():.4f}')
print(f' min: {x_normed[0].min().item():.4f}')
print(f' max: {x_normed[0].max().item():.4f}')
print()
print('sample 1')
print(f' mean: {x_normed[1].mean().item():.6f}')
print(f' std: {x_normed[1].std().item():.4f}')
print(f' min: {x_normed[1].min().item():.4f}')
print(f' max: {x_normed[1].max().item():.4f}')
print()
print('sample 2')
print(f' mean: {x_normed[2].mean().item():.6f}')
print(f' std: {x_normed[2].std().item():.4f}')
print(f' min: {x_normed[2].min().item():.4f}')
print(f' max: {x_normed[2].max().item():.4f}')
print()
print('now all samples have mean ≈ 0 and std ≈ 1!')
print('layer normalization standardizes each sample independently')Output:
examining data AFTER normalization
sample 0
mean: -0.000000
std: 1.0000
min: -2.4425
max: 2.1008
sample 1
mean: 0.000000
std: 1.0000
min: -2.6687
max: 2.0271
sample 2
mean: 0.000000
std: 1.0000
min: -2.6605
max: 2.0694
now all samples have mean ≈ 0 and std ≈ 1!
layer normalization standardizes each sample independently
# verify normalization for all samples
print('verifying normalization for all samples')
print()
means = x_normed.mean(dim=1)
stds = x_normed.std(dim=1)
print(f'mean of all sample means: {means.mean().item():.8f} (should be ~0)')
print(f'mean of all sample stds: {stds.mean().item():.4f} (should be ~1)')
print()
print('all samples are now normalized!')Output:
verifying normalization for all samples
mean of all sample means: 0.00000000 (should be ~0)
mean of all sample stds: 1.0000 (should be ~1)
all samples are now normalized!
Let's visualize what each normalization normalizes over.
For data of shape (batch_size, features) = (32, 100).
-
BatchNorm: Normalizes each FEATURE across all samples
- For feature 0: compute mean and std across all 32 samples
- Result: each feature has mean=0, std=1 across the batch
-
LayerNorm: Normalizes each SAMPLE across all features
- For sample 0: compute mean and std across all 100 features
- Result: each sample has mean=0, std=1 across its features
# demonstrate the difference between BatchNorm and LayerNorm
print('demonstrating the difference between BatchNorm and LayerNorm')
print()
print('our data has shape (32, 100) = (batch_size, features)')
print()
print('LAYERNORM normalizes across FEATURES (row-wise)')
print('checking ONE SAMPLE across all features')
print(f' sample 0 - mean: {x_normed[0,:].mean().item():.6f}, std: {x_normed[0,:].std().item():.4f}')
print(f' sample 1 - mean: {x_normed[1,:].mean().item():.6f}, std: {x_normed[1,:].std().item():.4f}')
print(f' sample 2 - mean: {x_normed[2,:].mean().item():.6f}, std: {x_normed[2,:].std().item():.4f}')
print(' → all samples are normalized (mean≈0, std≈1)')
print()
print('BATCHNORM would normalize across SAMPLES (column-wise)')
print('checking ONE FEATURE across all samples')
print(f' feature 0 - mean: {x_normed[:,0].mean().item():.4f}, std: {x_normed[:,0].std().item():.4f}')
print(f' feature 1 - mean: {x_normed[:,1].mean().item():.4f}, std: {x_normed[:,1].std().item():.4f}')
print(f' feature 2 - mean: {x_normed[:,2].mean().item():.4f}, std: {x_normed[:,2].std().item():.4f}')
print(' → features are NOT normalized (this is what BatchNorm would normalize)')
print()
print('LayerNorm normalizes each sample independently')
print('BatchNorm would normalize each feature across the batch')Output:
demonstrating the difference between BatchNorm and LayerNorm
our data has shape (32, 100) = (batch_size, features)
LAYERNORM normalizes across FEATURES (row-wise)
checking ONE SAMPLE across all features
sample 0 - mean: -0.000000, std: 1.0000
sample 1 - mean: 0.000000, std: 1.0000
sample 2 - mean: 0.000000, std: 1.0000
→ all samples are normalized (mean≈0, std≈1)
BATCHNORM would normalize across SAMPLES (column-wise)
checking ONE FEATURE across all samples
feature 0 - mean: -0.1428, std: 1.0629
feature 1 - mean: 0.0941, std: 1.1109
feature 2 - mean: 0.2947, std: 1.1421
→ features are NOT normalized (this is what BatchNorm would normalize)
LayerNorm normalizes each sample independently
BatchNorm would normalize each feature across the batch
# why LayerNorm is preferred in Transformers
print('why LayerNorm is preferred in Transformers')
print()
print('1. works with any batch size')
print(' - BatchNorm needs multiple samples to compute batch statistics')
print(' - LayerNorm only needs one sample (normalizes across features)')
print(' - important for inference with batch_size=1')
print()
print('2. each token is normalized independently')
print(' - in Transformers, each position is like a "sample"')
print(' - we want each token\'s representation to be well-conditioned')
print(' - no dependency on other samples in the batch')
print()
print('3. no running statistics needed')
print(' - BatchNorm maintains running mean/variance during training')
print(' - LayerNorm computes fresh statistics each forward pass')
print(' - simpler implementation and behavior')Output:
why LayerNorm is preferred in Transformers
1. works with any batch size
- BatchNorm needs multiple samples to compute batch statistics
- LayerNorm only needs one sample (normalizes across features)
- important for inference with batch_size=1
2. each token is normalized independently
- in Transformers, each position is like a "sample"
- we want each token's representation to be well-conditioned
- no dependency on other samples in the batch
3. no running statistics needed
- BatchNorm maintains running mean/variance during training
- LayerNorm computes fresh statistics each forward pass
- simpler implementation and behavior
# final summary of Part 6
print('SUMMARY: Full Self-Attention with Q, K, V')
print('=' * 60)
print()
print('SELF-ATTENTION MECHANISM')
print(' step 1: project input to queries, keys, values')
print(' Q = query(x) # "what am I looking for?"')
print(' K = key(x) # "what do I contain?"')
print(' V = value(x) # "what will I give?"')
print()
print(' step 2: compute attention scores')
print(' wei = Q @ K^T')
print(' wei[i,j] = how much position i attends to position j')
print()
print(' step 3: scale by 1/sqrt(head_size)')
print(' wei = wei * (head_size ** -0.5)')
print(' prevents variance explosion')
print()
print(' step 4: mask future positions')
print(' wei = wei.masked_fill(tril == 0, -inf)')
print(' prevents looking at future tokens')
print()
print(' step 5: apply softmax')
print(' wei = softmax(wei)')
print(' converts to probabilities that sum to 1')
print()
print(' step 6: weighted sum of values')
print(' out = wei @ V')
print(' each position gets weighted average of values')
print()
print('LAYER NORMALIZATION')
print(' normalizes each sample to mean=0, std=1')
print(' stabilizes training by keeping activations in good range')
print(' applied before attention and feedforward layers')
print()
print('COMPLETE ATTENTION FORMULA')
print(' Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V')Output:
SUMMARY: Full Self-Attention with Q, K, V
============================================================
SELF-ATTENTION MECHANISM
step 1: project input to queries, keys, values
Q = query(x) # "what am I looking for?"
K = key(x) # "what do I contain?"
V = value(x) # "what will I give?"
step 2: compute attention scores
wei = Q @ K^T
wei[i,j] = how much position i attends to position j
step 3: scale by 1/sqrt(head_size)
wei = wei * (head_size ** -0.5)
prevents variance explosion
step 4: mask future positions
wei = wei.masked_fill(tril == 0, -inf)
prevents looking at future tokens
step 5: apply softmax
wei = softmax(wei)
converts to probabilities that sum to 1
step 6: weighted sum of values
out = wei @ V
each position gets weighted average of values
LAYER NORMALIZATION
normalizes each sample to mean=0, std=1
stabilizes training by keeping activations in good range
applied before attention and feedforward layers
COMPLETE ATTENTION FORMULA
Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V
