diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 2b8988313a0..96b654d0e5a 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -292,8 +292,7 @@ def attention(
     )

     out = torch.empty_like(q)
-
-    return flash_attn_cuda.fwd(
+    flash_attn_cuda.fwd(
         q,
         k,
         v,
@@ -309,4 +308,5 @@ def attention(
         False,
         0,
         None,
-    )[0]
+    )
+    return out
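A reading of the patch (inferred from the diff, not stated in it): `out = torch.empty_like(q)` is passed to `flash_attn_cuda.fwd` in the elided part of the argument list, the kernel fills that buffer in place, and the function can therefore return its own `out` instead of unpacking `fwd(...)[0]` from a return tuple whose layout can vary across flash-attn versions. Below is a minimal sketch of that output-buffer pattern; `hypothetical_kernel` is an illustrative stand-in, not the real flash-attn binding, and its tuple return only mimics the auxiliary values `fwd` produces.

```python
import torch


def hypothetical_kernel(q, k, v, out):
    # Stand-in for flash_attn_cuda.fwd: computes attention and writes the
    # result into `out` in place. NOT the real flash-attn kernel; plain
    # scaled-dot-product attention serves as placeholder math.
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    out.copy_(weights @ v)
    # Mimic a kernel that also returns auxiliary tensors in a tuple.
    return (out,)


def attention(q, k, v):
    # Pre-allocate the output buffer; the kernel fills it in place.
    out = torch.empty_like(q)
    hypothetical_kernel(q, k, v, out)
    # Return the buffer we own rather than indexing into the kernel's
    # return tuple, whose shape differs between library versions.
    return out


# Usage: shapes (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
print(attention(q, k, v).shape)  # torch.Size([1, 2, 8, 16])
```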