small improvement in flash attention #1732

Merged 1 commit on Jun 26, 2024
35 changes: 35 additions & 0 deletions CMakeLists.txt
@@ -596,6 +596,41 @@ if (WITH_CUDA)
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
)

set_source_files_properties(
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
PROPERTIES COMPILE_FLAGS "--use_fast_math")
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
else()
12 changes: 7 additions & 5 deletions include/ctranslate2/layers/attention_layer.h
@@ -89,14 +89,14 @@ namespace ctranslate2 {
const dim_t max_position_embeddings = 0,
const bool transpose = true);

- void apply(StorageView& x, const dim_t offset = 0, bool apply = true);
+ void apply(StorageView& x, const dim_t offset = 0, bool fa2 = false);

- StorageView& get_cos() {
-   return _cos;
+ StorageView& get_cos_half() {
+   return *_cos_half;
}

- StorageView& get_sin() {
-   return _sin;
+ StorageView& get_sin_half() {
+   return *_sin_half;
}

bool get_interleave() const {
@@ -124,6 +124,8 @@ namespace ctranslate2 {

StorageView _sin;
StorageView _cos;
+ std::unique_ptr<StorageView> _sin_half;
+ std::unique_ptr<StorageView> _cos_half;
};


2 changes: 1 addition & 1 deletion include/ctranslate2/layers/flash_attention.h
@@ -47,7 +47,7 @@ namespace ctranslate2 {
dim_t beam_size = 1);

const dim_t _cache_time_dim;
- static constexpr dim_t _offset_free_space{100};
+ static constexpr dim_t _offset_free_space{512};
};
}
}
51 changes: 0 additions & 51 deletions include/ctranslate2/ops/flash-attention/philox.cuh

This file was deleted.

1 change: 0 additions & 1 deletion include/ctranslate2/ops/flash-attention/softmax.h
@@ -10,7 +10,6 @@

#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"

#ifndef M_LOG2E
16 changes: 14 additions & 2 deletions src/layers/attention_layer.cc
@@ -201,7 +201,7 @@ namespace ctranslate2 {
_rotary_scaling_short_factor = std::make_unique<StorageView>(_rotary_scaling_short_factor->to(Device::CPU));
}

- void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool apply) {
+ void RotaryEmbeddings::apply(StorageView& x, const dim_t offset, bool fa2) {
const Device device = x.device();
const DataType dtype = x.dtype();
const dim_t max_time = _transpose ? x.dim(-2) : x.dim(-3);
@@ -211,8 +211,20 @@ namespace ctranslate2 {
const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions);
initialize(new_num_positions, dim, device, dtype);
+ if (fa2) {
+   if (!_sin_half)
+   {
+     _sin_half = std::make_unique<StorageView>(dtype, device);
+     _cos_half = std::make_unique<StorageView>(dtype, device);
+   }
+   const ops::Slide slide_op(1, 0, dim / 2);
+   slide_op(_cos, *_cos_half);
+   slide_op(_sin, *_sin_half);
+   if (offset != 0)
+     return;
+ }
}
- if (!apply)
+ if (offset != 0 && fa2)
return;

StorageView sin(dtype, device);
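Note: the hunk above computes half-width rotary tables (`_cos_half`, `_sin_half`) with the `Slide` op only when the rotary cache is (re)built, instead of slicing the full tables on every decoding step inside the flash-attention op. A minimal NumPy sketch of the idea, assuming the cached tables have shape `(positions, dim)` with the two halves repeated for the non-interleaved layout (the `build_rotary_tables` helper and its parameters are illustrative, not code from this PR):

```python
import numpy as np

def build_rotary_tables(max_pos, dim, base=10000.0):
    # Illustrative rotary cache: one row per position, one column per head dimension.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))   # (dim/2,)
    angles = np.outer(np.arange(max_pos), inv_freq)           # (max_pos, dim/2)
    angles = np.concatenate([angles, angles], axis=-1)        # (max_pos, dim)
    return np.cos(angles), np.sin(angles)

cos, sin = build_rotary_tables(max_pos=2048, dim=128)

# What Slide(1, 0, dim / 2) keeps: the first half of axis 1, computed once and cached.
cos_half = cos[:, : cos.shape[1] // 2]   # plays the role of *_cos_half
sin_half = sin[:, : sin.shape[1] // 2]   # plays the role of *_sin_half
```

If the two halves of the cached tables are identical, as in this assumed non-interleaved layout, keeping only the first half loses no information and removes a per-step slice from the decode path.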
15 changes: 8 additions & 7 deletions src/layers/flash_attention.cc
@@ -16,17 +16,18 @@ namespace ctranslate2 {
}

void FlashMultiHeadAttention::operator()(const StorageView& queries,
- const StorageView& values,
+ const StorageView&,
const StorageView* values_lengths,
StorageView& output,
StorageView* cached_keys,
StorageView* cached_values,
StorageView* attention,
const Padder* queries_padder,
- const Padder* values_padder,
+ const Padder*,
bool return_normalized_attention,
- StorageView* position_bias,
+ StorageView*,
dim_t offset) const {
+ PROFILE("MultiHeadAttention");
const Device device = queries.device();
const DataType dtype = queries.dtype();

@@ -63,8 +64,8 @@ }
}

if (_rotary_embeddings) {
- _rotary_embeddings->apply(queries_proj, offset, offset == 0);
- _rotary_embeddings->apply(keys_proj, offset, offset == 0);
+ _rotary_embeddings->apply(queries_proj, offset, true);
+ _rotary_embeddings->apply(keys_proj, offset, true);
}

if (cached_keys != nullptr) {
@@ -102,8 +103,8 @@
StorageView* rotary_sin = nullptr;
bool rotary_interleaved = false;
if (_rotary_embeddings && offset > 0) {
- rotary_cos = &(_rotary_embeddings->get_cos());
- rotary_sin = &(_rotary_embeddings->get_sin());
+ rotary_cos = &(_rotary_embeddings->get_cos_half());
+ rotary_sin = &(_rotary_embeddings->get_sin_half());
rotary_interleaved = _rotary_embeddings->get_interleave();
}

1 change: 1 addition & 0 deletions src/ops/flash_attention.cc
@@ -23,6 +23,7 @@ namespace ctranslate2 {
const bool rotary_interleave,
StorageView* alibi,
dim_t offset) const {
PROFILE("FlashAttention");
DEVICE_DISPATCH(queries.device(), compute<D>(queries, keys, values, output, cached_keys, cached_values,
attention, return_normalized_attention,
rotary_cos, rotary_sin, rotary_interleave, alibi, offset));
11 changes: 3 additions & 8 deletions src/ops/flash_attention_gpu.cu
@@ -205,8 +205,6 @@ namespace ctranslate2 {
dim_t offset) const {
const Device device = queries.device();
const DataType dtype = queries.dtype();
- StorageView rotary_cos_half(dtype, device);
- StorageView rotary_sin_half(dtype, device);

dim_t window_size_left = _sliding_window > 0 ? _sliding_window : -1;
dim_t window_size_right = _sliding_window > 0 ? 0 : -1;
@@ -324,12 +322,9 @@
params.is_seqlens_k_cumulative = false;

if (rotary_cos && rotary_sin) {
- params.rotary_dim = rotary_cos->dim(1);
- const ops::Slide slide_op(1, 0, params.rotary_dim / 2);
- slide_op(*rotary_cos, rotary_cos_half);
- slide_op(*rotary_sin, rotary_sin_half);
- params.rotary_cos_ptr = rotary_cos_half.buffer();
- params.rotary_sin_ptr = rotary_sin_half.buffer();
+ params.rotary_dim = rotary_cos->dim(1) * 2;
+ params.rotary_cos_ptr = rotary_cos->buffer();
+ params.rotary_sin_ptr = rotary_sin->buffer();
params.is_rotary_interleaved = rotary_interleave;
}
else
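Since the cos/sin buffers handed to the kernel are now already half-width, `rotary_dim` is recovered as twice their last dimension and the pointers are passed through directly, without the per-call `Slide`. A rough NumPy sketch of a non-interleaved rotation using such half tables, purely illustrative and not a transcription of the CUDA kernel (shapes and names are assumptions):

```python
import numpy as np

def apply_rotary_half(x, cos_half, sin_half, positions):
    # x: (tokens, heads, head_dim); cos_half / sin_half: (cached_positions, head_dim / 2)
    rotary_dim = cos_half.shape[-1] * 2        # mirrors params.rotary_dim = rotary_cos->dim(1) * 2
    c = cos_half[positions][:, None, :]        # (tokens, 1, rotary_dim / 2)
    s = sin_half[positions][:, None, :]
    x1 = x[..., : rotary_dim // 2]
    x2 = x[..., rotary_dim // 2 : rotary_dim]
    rotated = np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)
    return np.concatenate([rotated, x[..., rotary_dim:]], axis=-1)

# Example: rotate 4 new tokens that start at cache offset 10.
angles = np.outer(np.arange(2048), 1.0 / 10000.0 ** (np.arange(64) / 64))
cos_half, sin_half = np.cos(angles), np.sin(angles)
x = np.random.rand(4, 8, 128)
y = apply_rotary_half(x, cos_half, sin_half, positions=np.arange(10, 14))
```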
13 changes: 7 additions & 6 deletions tools/benchmark_tensor_parallel/benchmark.py
@@ -49,8 +49,8 @@ def process_prompt(generator, max_generation_length, generated_token, prompt):
step_results = generator.generate_tokens(
prompt,
max_length=max_generation_length,
- sampling_temperature=0.6,
- sampling_topk=20,
+ sampling_temperature=0.75,
+ sampling_topk=1,
sampling_topp=1,
)
for step_result in step_results:
@@ -77,13 +77,13 @@ def benchmark_generation(generator,
step_results = generator.generate_tokens(
prompt_tokens[i:i + batch_size],
max_length=max_generation_length,
- sampling_temperature=0.6,
- sampling_topk=20,
+ sampling_temperature=0.75,
+ sampling_topk=1,
sampling_topp=1,
)
for step_result in step_results:
batch_id = step_result.batch_id
- generated_token[batch_id].append(step_result.token)
+ generated_token[i + batch_id].append(step_result.token)
end_all = time.time()
elapsed_time = end_all - start_all
num_tokens = count_tokens(generated_token)
@@ -148,7 +148,8 @@ def main():
args = parser.parse_args()

print("Loading the model...")
- generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True, inter_threads=2)
+ generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True,
+                                   flash_attention=False, inter_threads=2)
sp = spm.SentencePieceProcessor(os.path.join(args.model_path, "tokenizer.model"))

if not os.path.exists(args.src):
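The `generated_token[i + batch_id]` change accounts for `batch_id` being relative to the current sub-batch `prompt_tokens[i:i + batch_size]`, so each result is appended to the right slot of the global list. A toy illustration of that offset (hypothetical data, not from the benchmark):

```python
prompts = ["p0", "p1", "p2", "p3", "p4"]
batch_size = 2
generated = [[] for _ in prompts]

for i in range(0, len(prompts), batch_size):
    chunk = prompts[i:i + batch_size]
    for batch_id, prompt in enumerate(chunk):   # batch_id restarts at 0 for every chunk
        generated[i + batch_id].append(f"token-for-{prompt}")

print(generated)   # one entry per prompt, in the original order
```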