[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model #4799

Merged
101 commits merged on May 25, 2024

Commits (101)
0e4c28d
vllm format w/ mha
beagleski Apr 8, 2024
b2e7c0a
original triton kernel
beagleski Apr 8, 2024
d5308c5
added prompt phase bs attn
linxihui Apr 9, 2024
d73cdb3
minor change
linxihui Apr 9, 2024
176275e
changes to make prompt with bs attn works
linxihui Apr 9, 2024
1116e01
wip for backup
linxihui Apr 10, 2024
24ab443
paged attn kernel with blocksparse support
linxihui Apr 11, 2024
b6c2ebe
some cleaning
linxihui Apr 11, 2024
4e28773
some cleaning
linxihui Apr 11, 2024
bccec2f
some cleaning
linxihui Apr 11, 2024
ab0df74
support tp
linxihui Apr 12, 2024
f11d590
fixed TP
linxihui Apr 12, 2024
1670a3d
flash2 logic + q broadcast
linxihui Apr 12, 2024
8531eaa
clean up
linxihui Apr 13, 2024
b20312c
split local
linxihui Apr 15, 2024
26c6222
split-local
linxihui Apr 15, 2024
0a52b2b
finished spliting local and stride
linxihui Apr 16, 2024
f85da14
clean up
linxihui Apr 16, 2024
0ee826b
added sparse support
linxihui Apr 16, 2024
3891f22
seem to work, but need binding and unit test.
linxihui Apr 16, 2024
1440eba
add binding
beagleski Apr 17, 2024
6571c58
add more backend interface; change is_sparse to be guarded by blocksp…
beagleski Apr 17, 2024
7143bac
refactor to phi3
beagleski Apr 18, 2024
a1f37a9
longrope support
linxihui Apr 18, 2024
8ff8be7
code cleaning; larger block_size
linxihui Apr 19, 2024
439c7c7
LongRoPE
linxihui Apr 19, 2024
bfba8d5
v100 support
linxihui Apr 23, 2024
9473082
Merge branch 'eric/cuda-kernel-longrope' into eric/bs-attn-longrope-spda
linxihui Apr 23, 2024
f4c53d3
bs backend for prompt
linxihui Apr 23, 2024
809f3f5
merge eric/cuda-kernle
linxihui Apr 23, 2024
7868f0a
not need to import flash in v100
linxihui Apr 23, 2024
7d92de3
folder/files re-org
linxihui Apr 23, 2024
ca27e7a
folder/files restructure
linxihui Apr 23, 2024
7c0cfd7
finished file/folder restructure
linxihui Apr 23, 2024
5389e35
added missing files
linxihui Apr 24, 2024
e5747f8
minor bug fixed for v100
linxihui Apr 24, 2024
c68ecb5
minor cleanup
linxihui Apr 24, 2024
0718c86
code cleaning
linxihui Apr 24, 2024
85b0ed5
used bf16 for cpu by default; rm lru_cache to prevent leaking; minor …
linxihui Apr 24, 2024
66b04d6
used env var to control if to use the Triton or cuda paged attn kernel
linxihui Apr 24, 2024
c39be85
rm file not long needed
linxihui Apr 24, 2024
b2df3f7
styling change
beagleski Apr 24, 2024
2ff8778
Run yapf and ruff
beagleski Apr 25, 2024
bb0ff75
clean up and add doc
linxihui Apr 25, 2024
e5c7212
fixed folder typo and preempted __init__ files
linxihui Apr 25, 2024
dbd6b47
rm phi3small config and tokenizer, as not needed; changed to meet for…
linxihui Apr 26, 2024
8eac29c
clean up; revert unnecessary change; run auto formatting
linxihui Apr 26, 2024
71663e8
clean up
linxihui Apr 26, 2024
0be4ce2
typo fix
linxihui Apr 26, 2024
aa65d2e
revert changes in vllm/utils.py
beagleski Apr 30, 2024
fd5486a
suppress dummy token ids if any
linxihui Apr 30, 2024
63b9bb8
minor fix for py3.8 on type annoton
linxihui Apr 30, 2024
8d3ec74
minor change for LongRoPE config to account for rename from longrope …
codedecde May 3, 2024
e1dd365
handling TP slicing on the vllm side for dummy tokens fix
codedecde May 3, 2024
561d5a8
Merge pull request #3 from beagleski/bapatra/bugfix-longrope-type
codedecde May 4, 2024
bfd3c80
patching for having type su
codedecde May 5, 2024
7646e00
Merge pull request #4 from beagleski/bapatra/patching-for-su
linxihui May 6, 2024
201c2c1
built sparse kernel separatedly at compile time
linxihui May 6, 2024
c1f7c26
updated ops header file for paged_attn; remove blocksparse from flash…
linxihui May 7, 2024
3db3010
merge rc1
linxihui May 7, 2024
55f0d4b
suppress warning on beta version of csr
linxihui May 7, 2024
36009b4
merged public vllm main branch; resolve config; change accordingly
linxihui May 7, 2024
0c6b10c
removed test as it is not working due to missing of public model id
linxihui May 7, 2024
a989bc4
fixed according to new change in vllm
linxihui May 8, 2024
a3efa6a
fixed bugs when sparse config is not the same phi3small, tp>1 when ve…
linxihui May 8, 2024
a26c269
formatted
linxihui May 8, 2024
d3f2943
fixed doc
linxihui May 10, 2024
525c48d
replace -inf to -9999 in prompt probs as it cannot be serialized as json
linxihui May 10, 2024
9b7d192
moved kernels and backends to vllm/attention; refactor phi3small
linxihui May 10, 2024
2eefeda
minor changes
linxihui May 10, 2024
90a0a87
merged public main
linxihui May 13, 2024
300797c
minor changes
linxihui May 13, 2024
d5cd48c
merged public main 5/13
linxihui May 13, 2024
04b3cdc
changed according to formatting requirment
linxihui May 13, 2024
83b23e5
removed unused import
linxihui May 13, 2024
87bd2ac
fixed formatting requirment
linxihui May 13, 2024
d6ea404
removed paged attn triton kernel as not used
linxihui May 13, 2024
8e86707
clean
linxihui May 13, 2024
69d412e
formating change
linxihui May 13, 2024
33a1930
Merge pull request #5 from beagleski/eric/bs-attn-and-phi3small
linxihui May 14, 2024
e9dc082
Merge branch 'main' of https://github.com/linxihui/vllm
linxihui May 14, 2024
dfc07c7
minor change
linxihui May 14, 2024
1197728
Update vllm/attention/ops/blocksparse_attention/interface.py
linxihui May 14, 2024
2955cec
Update vllm/attention/backends/blocksparse_attn.py
linxihui May 14, 2024
eb16d9a
fixed according to suggestion by @mgoin
linxihui May 14, 2024
1600156
added unittest
linxihui May 21, 2024
e7f9918
add unittest for blocksaprse
linxihui May 24, 2024
2afd8b1
reverted changes
linxihui May 24, 2024
def0c4c
refactored blocksparse code for better readability
linxihui May 24, 2024
359cc7f
used default values to blocksparse params to be compatible with old i…
linxihui May 24, 2024
6d0441b
default value to be consistent
linxihui May 24, 2024
52bf2b5
replaced pytorch sparse matrix utils to scipy.sparse to avoid annoyin…
linxihui May 24, 2024
97f3662
added unittest for blocksparse attn prefilling and blocksparse paged …
linxihui May 24, 2024
754e306
merged upstream/main; resolved conflict
linxihui May 24, 2024
c834882
clean; format
linxihui May 24, 2024
644fc14
updated metadata def
linxihui May 24, 2024
435dd38
run bash format.sh
linxihui May 24, 2024
8a22c26
clang-format
linxihui May 24, 2024
8554331
ruff fix
linxihui May 24, 2024
547692e
matched interface with gpu verison; throw error if blocksparse attn i…
linxihui May 25, 2024
daf94f3
run clang-format
linxihui May 25, 2024
185 changes: 148 additions & 37 deletions csrc/attention/attention_kernels.cu

Large diffs are not rendered by default.

37 changes: 21 additions & 16 deletions csrc/cpu/attention.cpp
@@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
}
} // namespace

void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
}
} // namespace

void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
35 changes: 18 additions & 17 deletions csrc/ops.h
@@ -2,23 +2,24 @@

#include <torch/extension.h>

void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale);

void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale);
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);

void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon);
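The new trailing arguments (tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) describe the sparsity pattern rather than kernel internals. Below is a minimal Python sketch of the local-plus-vertical-stride block mask that these parameters suggest; the function name, the exact per-head offset convention, and the omission of any tp_rank adjustment are illustrative assumptions, not the kernel's actual definition.

import numpy as np

def blocksparse_block_mask(num_blocks: int, head_idx: int,
                           local_blocks: int, vert_stride: int,
                           head_sliding_step: int) -> np.ndarray:
    """Boolean [num_blocks, num_blocks] map of which key/value blocks each
    query block may attend to (True = attend). Illustrative sketch only."""
    q = np.arange(num_blocks)[:, None]  # query block index
    k = np.arange(num_blocks)[None, :]  # key/value block index
    causal = k <= q
    if vert_stride <= 1:
        # Mirrors the CPU guard above: vert_stride <= 1 means no vertical
        # sparsity, i.e. plain dense causal attention.
        return causal
    local = (q - k) < local_blocks  # the most recent `local_blocks` blocks
    # every `vert_stride`-th block, with the strided columns shifted per head
    vertical = (k + head_idx * head_sliding_step + 1) % vert_stride == 0
    return causal & (local | vertical)

# e.g. 16 blocks of blocksparse_block_size tokens each, head 0,
# 4 local blocks, vertical stride 8
print(blocksparse_block_mask(16, 0, 4, 8, 1).astype(int))

Under such a layout each query block touches roughly local_blocks + num_blocks / vert_stride key/value blocks instead of all of them, which is the saving the blocksparse paged-attention kernels are built around.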
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -123,6 +123,10 @@ Alongside each architecture, we include some popular models that use it.
- Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
-
* - :code:`Phi3SmallForCausalLM`
- Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
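For reference, a short offline-inference sketch against the newly listed checkpoint using vLLM's LLM API; trust_remote_code=True is assumed to be needed for the Phi-3-Small configuration, and the prompt and sampling values are placeholders.

from vllm import LLM, SamplingParams

# Phi-3-Small ships a custom configuration/tokenizer, so remote code is
# assumed to be required when loading the checkpoint.
llm = LLM(model="microsoft/Phi-3-small-8k-instruct", trust_remote_code=True)

sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Summarize blocksparse attention in one sentence."],
                       sampling)
print(outputs[0].outputs[0].text)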