docs: new kernel design #179

Draft · wants to merge 8 commits into main
99 changes: 75 additions & 24 deletions tutorial/4 - flash attention.ipynb
@@ -123,10 +123,7 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
"collapsed": false
},
"outputs": [],
"source": [
@@ -137,10 +134,10 @@
" Kj = K_mat[block_start_Bc:block_end_Bc, :] # shape Bc x d\n",
" for block_start_Br in range(0, N, Br):\n",
" block_end_Br = block_start_Br + Br\n",
" Qi = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
" q = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
"\n",
" # QKt at the tile level\n",
" Sij = Qi @ Kj.T # shape Br x Bc\n",
" Sij = q @ Kj.T # shape Br x Bc\n",
" S_mat_for_check[block_start_Br:block_end_Br, block_start_Bc:block_end_Bc] += Sij\n",
"\n",
"assert torch.allclose(S_mat_for_check, Q_mat @ K_mat.T)"
@@ -170,10 +167,7 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
"collapsed": false
},
"outputs": [],
"source": [
@@ -185,10 +179,10 @@
" Vj = V_mat[block_start_Bc:block_end_Bc, :] # shape Bc x d\n",
" for block_start_Br in range(0, N, Br):\n",
" block_end_Br = block_start_Br + Br\n",
" Qi = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
" q = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
"\n",
" # QKt at the tile level\n",
" Sij = Qi @ Kj.T # shape Br x Bc\n",
" Sij = q @ Kj.T # shape Br x Bc\n",
" Oi = Sij @ Vj # shape Br x d\n",
" O[block_start_Br:block_end_Br, :] += Oi\n",
"\n",
@@ -243,16 +237,13 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
"collapsed": false
},
"outputs": [],
"source": [
"# variables outside the for loop represent the global memory\n",
"# they are the only ones bigger than what the SRAM can store\n",
"O = torch.zeros((N, d))\n",
"O.zero_() # reset output\n",
"\n",
"# For the 2 variables below, they may be removed in a serially executed code (in particular the outter for loop)\n",
"# They are needed in parallelized execution where each thread block need to sync its findings with the others\n",
@@ -273,10 +264,10 @@
" mi = m[block_start_Br:block_end_Br, :] # shape Br x 1\n",
" li = l[block_start_Br:block_end_Br, :] # shape Br x 1\n",
" Oi = O[block_start_Br:block_end_Br, :] # shape Br x d\n",
" Qi = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
" q = Q_mat[block_start_Br:block_end_Br, :] # shape Br x d\n",
"\n",
" # line 9, QKt at the tile level\n",
" Sij = Qi @ Kj.T # shape Br x Bc\n",
" Sij = q @ Kj.T # shape Br x Bc\n",
"\n",
" # line 10, find max of each row of the current loaded block (and only this block)\n",
" mij_hat = torch.max(Sij, dim=1).values[:, None]\n",
@@ -307,18 +298,78 @@
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
"collapsed": false
},
"source": [
"## Triton implementation\n",
"\n",
"Triton implementation of Flash attention and original Flash attention Cuda implementations differ on an important point: the way they are parallelized.\n",
"\n",
"In Cuda implementation, it's quite simple, algorithm above is executed in a serialized way. The parallelization only happens at the `head x batch` level (so it needs on `A100` at least head x batch >= 80 to keep the GPU busy).\n",
"\n",
"In Triton implementation, the inner and outer loops in the algo above are switched and the parallelization happens at the level of the outer loop, it increases the level of parallelization and it makes the GPU busy even for small batches / low number of heads. See https://github.com/HazyResearch/flash-attention/issues/40 for detailed analysis."
"In Triton implementation, the inner and outer loops in the algo above are switched and the parallelization happens at the level of the outer loop, it increases the level of parallelization and it makes the GPU busy even for small batches / low number of heads. See https://github.com/HazyResearch/flash-attention/issues/40 for detailed analysis.\n",
"\n",
"You will find below a slightly simplified PyTorch transcription of the Triton kernel we use in this library.\n",
"All variable names are from the Triton Kernel, but those already declared in relation with tensor shapes."
]
},
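{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"As an aside, here is a minimal sketch (an illustration under assumptions, not this library's actual kernel) of how the outer loop becomes implicit in Triton: each program instance owns one block of query rows, selected through `tl.program_id`. The kernel name, `BLOCK_M` and the pointer names are ours:\n",
"\n",
"```python\n",
"import triton\n",
"import triton.language as tl\n",
"\n",
"@triton.jit\n",
"def attention_fwd(q_ptr, k_ptr, v_ptr, o_ptr, N, BLOCK_M: tl.constexpr):\n",
"    # one program instance per block of query rows: this replaces the outer loop\n",
"    start_m = tl.program_id(0)\n",
"    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n",
"    # the inner loop over K/V blocks stays serial inside the program,\n",
"    # exactly as in the PyTorch transcription below\n",
"\n",
"# launched with a grid of triton.cdiv(N, BLOCK_M) program instances\n",
"```"
]
},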
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"O.zero_() # reset output\n",
"\n",
"for block_start_Bc in range(0, N, Bc): # <-- this loop is parallelized and is implicit in Triton implementation\n",
" block_end_Bc = block_start_Bc + Bc\n",
" q = Q_mat[block_start_Bc:block_end_Bc, :]\n",
" acc = torch.zeros(Bc, d)\n",
"\n",
" l_i = torch.zeros((Bc,), dtype=torch.float32) - float(\"inf\")\n",
" d_i = torch.zeros((Bc,), dtype=torch.float32)\n",
"\n",
" for block_start_Br in range(0, N, Br):\n",
" block_end_Br = block_start_Br + Br\n",
" k = K_mat[block_start_Br:block_end_Br, :]\n",
" qk = q @ k.T\n",
" l_j = torch.max(qk, dim=1).values\n",
"\n",
" numerators = torch.exp(qk - l_j[:, None])\n",
" d_j = torch.sum(numerators, 1)\n",
"\n",
" l_new = torch.maximum(l_i, l_j)\n",
" alpha = torch.exp(l_i - l_new)\n",
" beta = torch.exp(l_j - l_new)\n",
" d_new = alpha * d_i + beta * d_j\n",
"\n",
" p_scale = beta / d_new\n",
"\n",
" qk_softmax = numerators * p_scale[:, None]\n",
" acc_scale = d_i / d_new * alpha\n",
" acc = acc * acc_scale[:, None]\n",
"\n",
" v = V_mat[block_start_Br:block_end_Br, :]\n",
" acc += qk_softmax @ v\n",
" d_i = d_new\n",
" l_i = l_new\n",
"\n",
" O[block_start_Bc:block_end_Bc, :] += acc\n",
"\n",
"\n",
"assert torch.allclose(O, (torch.nn.functional.softmax(Q_mat @ K_mat.T, dim=1)) @ V_mat)"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
152 changes: 152 additions & 0 deletions tutorial/5 - skinny flash attention.ipynb
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(456)\n",
"\n",
"nrows, BLOCK_DHEAD = 16, 8\n",
"BLOCK_N = 2\n",
"n_segments = 4 # rename to segment number\n",
"assert nrows % n_segments == 0\n",
"nrows_per_segment = nrows // n_segments\n",
"assert nrows_per_segment % BLOCK_N == 0\n",
"nrows_per_block = nrows_per_segment // BLOCK_N\n",
"\n",
"Q_mat = torch.rand((1, BLOCK_DHEAD))\n",
"K_mat = torch.rand((nrows, BLOCK_DHEAD))\n",
"V_mat = torch.rand((nrows, BLOCK_DHEAD))\n",
"O_mat = torch.zeros_like(Q_mat)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# GM variables\n",
"d_j_split = torch.zeros((n_segments, nrows_per_block, Q_mat.shape[0]))\n",
"l_j_split = torch.zeros((n_segments, nrows_per_block, Q_mat.shape[0]))\n",
"acc_splitted = torch.zeros((n_segments, nrows_per_block, *O_mat.shape))\n",
"\n",
"q = Q_mat # load from GM\n",
"for segment_index in range(0, n_segments): # <- this loop is parallelized, it is implicit in the Triton implementation\n",
" for block_index_N, block_start_N in enumerate(range(0, nrows_per_segment, BLOCK_N)): # <- this loop is parallelized, it is implicit in the Triton implementation\n",
"\n",
" block_start_N += segment_index * nrows_per_segment\n",
" block_end_N = block_start_N + BLOCK_N\n",
"\n",
" k = K_mat[block_start_N:block_end_N, :] # load from GM\n",
" qk = q @ k.T\n",
" l_j = torch.max(qk, dim=1).values\n",
" l_j_split[segment_index, block_index_N, :] = l_j # saving to GM\n",
" numerators = torch.exp(qk - l_j[:, None]) # safe softmax numerator\n",
" d_j = torch.sum(numerators, dim=1)\n",
" d_j_split[segment_index, block_index_N, :] = d_j # saving to GM\n",
"\n",
" v = V_mat[block_start_N:block_end_N, :] # load from GM\n",
" o_segment = numerators @ v\n",
"\n",
" acc_splitted[segment_index, block_index_N, :, :] = o_segment # saving to GM\n"
]
},
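{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The reduction below merges the per-(segment, block) partial results saved to GM above. A sketch of the update the loop applies, in the notation of the code: with running max $l_i$, running denominator $d_i$ and an already-normalized accumulator $acc$, a partial result with block max $l_j$, block denominator $d_j$ and unnormalized partial output $acc_i$ is folded in as\n",
"\n",
"$$l_{new} = \\max(l_i, l_j), \\qquad d_{new} = e^{l_i - l_{new}} d_i + e^{l_j - l_{new}} d_j, \\qquad acc \\leftarrow \\frac{d_i e^{l_i - l_{new}}}{d_{new}} acc + \\frac{e^{l_j - l_{new}}}{d_{new}} acc_i$$"
]
},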
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.4563, 0.6146, 0.5144, 0.3769, 0.5411, 0.6278, 0.4206, 0.4573]])\n",
"tensor([[0.4563, 0.6146, 0.5144, 0.3769, 0.5411, 0.6278, 0.4206, 0.4573]])\n"
]
}
],
"source": [
"# shared memory variables\n",
"l_i = torch.zeros((Q_mat.shape[0],)) - float(\"inf\")\n",
"d_i = torch.zeros((Q_mat.shape[0],))\n",
"acc = torch.zeros_like(O_mat)\n",
"\n",
"\n",
"for block_index_N in range(0, nrows // (n_segments * BLOCK_N)):\n",
" for segment_index in range(0, n_segments):\n",
"\n",
" acc_i = acc_splitted[segment_index, block_index_N]\n",
" # l_j = l_j_split[segment_index][block_index_N]\n",
" l_j = l_j_split[segment_index, block_index_N]\n",
" d_j = d_j_split[segment_index][block_index_N]\n",
"\n",
" l_new = torch.maximum(l_i, l_j)\n",
" alpha = torch.exp(l_i - l_new)\n",
" beta = torch.exp(l_j - l_new)\n",
" d_new = alpha * d_i + beta * d_j\n",
"\n",
" p_scale = beta / d_new\n",
"\n",
" acc_i *= p_scale[:, None]\n",
" acc_scale = d_i / d_new * alpha\n",
" # scaling factor is applied to the exported matrix\n",
" acc = acc * acc_scale[:, None]\n",
" acc += acc_i # accumulating in shared memory\n",
" d_i = d_new\n",
" l_i = l_new\n",
"\n",
"O_mat = acc # write to GM\n",
"\n",
"print(O_mat)\n",
"print((torch.nn.functional.softmax(Q_mat @ K_mat.T, dim=1)) @ V_mat)\n",
"assert torch.allclose(O_mat, (torch.nn.functional.softmax(Q_mat @ K_mat.T, dim=1)) @ V_mat)"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "3.8.9"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}