Optimize data movement #20
Merged
Changes from all 19 commits:
28df307  Merge QKV for OPT (zhuohan123)
2e417f5  merge qkv for llama (zhuohan123)
06f23ff  fix the code according to woosuk's comment (zhuohan123)
da0fdd2  Merge branch 'main' into qkv_combined (zhuohan123)
b1ba1e4  Merge branch 'main' into qkv_combined (WoosukKwon)
9c5eca0  Add SiluAndMul (WoosukKwon)
a5719c1  Remove (WoosukKwon)
07fb828  Merge branch 'activation' into qkv_combined (WoosukKwon)
47622ec  Add SiluAndMul for fused SwiGLU (WoosukKwon)
c3816b8  Optimize data movement in attention (WoosukKwon)
b8d0024  Add activation_ops to setup.py (WoosukKwon)
3f8dd53  Make rotary embedding in-place (WoosukKwon)
07e2bca  Bug fix (WoosukKwon)
8a37545  Roll back attention arguments (WoosukKwon)
0417554  Merge branch 'main' into data-move (WoosukKwon)
e2a47cc  Fix test for reshape_and_cache (WoosukKwon)
df402bb  Fix test for rotary_embedding_neox (WoosukKwon)
7152271  Fix test for attention kernels (WoosukKwon)
0132133  Merge branch 'main' into data-move (WoosukKwon)
New file (+20 lines), Python wrapper for the fused activation op:

import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,  # (num_tokens, 2 * d)
    ) -> torch.Tensor:    # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
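This module is meant to pair with a merged gate/up projection so that the whole SwiGLU activation runs in one pass over a single (num_tokens, 2 * d) tensor (see the commit "Add SiluAndMul for fused SwiGLU"). A minimal usage sketch follows; the SwiGLUMLP class and its attribute names are illustrative, not code from this PR, and SiluAndMul is assumed to be in scope (as defined above):

import torch
import torch.nn as nn

class SwiGLUMLP(nn.Module):
    # Illustrative SwiGLU block: one merged gate/up projection to 2 * intermediate_size,
    # the fused SiluAndMul activation, then a down projection back to hidden_size.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = SiluAndMul()  # the module defined in the diff above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)             # (num_tokens, 2 * intermediate_size)
        return self.down_proj(self.act(gate_up))   # (num_tokens, hidden_size)

Merging the gate and up projections into one matmul, then fusing the activation, avoids materializing separate intermediate tensors between the two halves.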
(Diffs for several other changed files in this PR are not shown here.)
New file (+12 lines), pybind11 binding that exposes the op as activation_ops.silu_and_mul:

#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
}
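Per the commit "Add activation_ops to setup.py", this binding is compiled as a torch CUDA extension. A rough sketch of how such an extension is typically registered with torch.utils.cpp_extension; the source file paths below are assumptions for illustration, not taken from this PR:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="cacheflow",
    ext_modules=[
        CUDAExtension(
            # Imported in Python as `from cacheflow import activation_ops`.
            name="cacheflow.activation_ops",
            sources=[
                "csrc/activation.cpp",         # assumed path for the pybind11 binding above
                "csrc/activation_kernels.cu",  # assumed path for the CUDA kernel below
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)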
New file (+46 lines), CUDA kernel implementing silu_and_mul:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,          // [num_tokens, d]
  const scalar_t* __restrict__ input,  // [num_tokens, 2, d]
  const int d) {
  // One thread block per token; threads stride over the hidden dimension d.
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,    // [num_tokens, d]
  torch::Tensor& input)  // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  // Launch one block per token with up to 1024 threads covering d.
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
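A quick way to sanity-check the kernel is to compare it against an unfused PyTorch reference. This is a hypothetical check written in the spirit of the PR's kernel-test commits, not the actual test code; it assumes the extension has been built and a CUDA device is available:

import torch
import torch.nn.functional as F
from cacheflow import activation_ops

# Arbitrary sizes; d is deliberately larger than the 1024-thread block
# so the strided loop in the kernel is exercised.
num_tokens, d = 64, 11008
x = torch.randn(num_tokens, 2 * d, dtype=torch.half, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.half, device="cuda")

activation_ops.silu_and_mul(out, x)
ref = F.silu(x[:, :d]) * x[:, d:]   # unfused reference: SiLU(first half) * second half

assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)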
Review comment:
Just curious, so flash attention natively supports non-contiguous QKV tensors?
Reply:
Yes. It actually requires a QKV tensor of shape [num_tokens, 3, num_heads, head_size]. Previously, we inserted torch.stack to meet this shape requirement, and this PR eliminates that inefficiency.
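To illustrate the data movement being eliminated, here is a rough sketch (module names and sizes are made up for illustration, not code from this PR). With separate Q/K/V projections, the three outputs must be copied into a new contiguous buffer via torch.stack; with a merged QKV projection, the output only needs to be viewed in the [num_tokens, 3, num_heads, head_size] layout the attention kernel expects:

import torch

num_tokens, num_heads, head_size = 16, 32, 128
hidden_size = num_heads * head_size
x = torch.randn(num_tokens, hidden_size)

# Before: separate projections followed by torch.stack, which materializes an extra copy.
q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
q = q_proj(x).view(num_tokens, num_heads, head_size)
k = k_proj(x).view(num_tokens, num_heads, head_size)
v = v_proj(x).view(num_tokens, num_heads, head_size)
qkv_stacked = torch.stack([q, k, v], dim=1)  # copy into [num_tokens, 3, num_heads, head_size]

# After: one merged projection whose output is only reshaped, never copied,
# assuming the merged weight rows are ordered [Wq; Wk; Wv].
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
qkv = qkv_proj(x).view(num_tokens, 3, num_heads, head_size)  # no extra data movement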