
Commit 27daf69

Merge branch 'huggingface:main' into vb/followup-doc-fixes

2 parents: 03bfff5 + c6d5039
File tree: 7 files changed (+518 / -44 lines)


flake.lock

99 additions, 0 deletions (generated file; diff not rendered by default).

flake.nix

73 additions, 0 deletions (new file):

```nix
{
  inputs = {
    tgi-nix.url = "github:danieldk/tgi-nix";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
        pkgs = import nixpkgs {
          inherit config system;
          overlays = [ tgi-nix.overlay ];
        };
      in
      {
        devShells.default =
          with pkgs;
          mkShell {
            buildInputs =
              [
                cargo
                openssl.dev
                pkg-config
              ]
              ++ (with python3.pkgs; [
                venvShellHook
                pip

                einops
                fbgemm-gpu
                flash-attn
                flash-attn-layer-norm
                flash-attn-rotary
                grpc-interceptor
                grpcio-reflection
                grpcio-status
                hf-transfer
                loguru
                marlin-kernels
                opentelemetry-api
                opentelemetry-exporter-otlp
                opentelemetry-instrumentation-grpc
                opentelemetry-semantic-conventions
                peft
                tokenizers
                torch
                transformers
                vllm
              ]);

            venvDir = "./.venv";

            postVenv = ''
              unset SOURCE_DATE_EPOCH
            '';
            postShellHook = ''
              unset SOURCE_DATE_EPOCH
            '';
          };
      }
    );
}
```
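For context: with a flakes-enabled Nix, a dev shell like this is normally entered with `nix develop`, after which `venvShellHook` creates the `./.venv` virtualenv; unsetting `SOURCE_DATE_EPOCH` in the hooks is presumably there to keep pip from stamping packages with Nix's fixed build timestamp inside that venv.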

server/text_generation_server/layers/attention/common.py

2 additions, 2 deletions:

```diff
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
 import torch
 from typing import Optional
 
 
-if FLASH_DECODING:
+if FLASH_DECODING or FLASH_INFER:
 
     @dataclass
     class Seqlen:
```
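FLASH_DECODING and FLASH_INFER come from text_generation_server.models.globals, which is not part of this diff; in TGI such switches are typically driven by environment variables. The sketch below shows that general pattern under that assumption; the helper name `_env_flag` and the accepted truthy values are illustrative, not the actual globals.py.

```python
# Hedged sketch: how module-level feature flags such as FLASH_INFER are often
# derived from environment variables. The real models/globals.py may differ.
import os


def _env_flag(name: str) -> bool:
    # Hypothetical helper: treat "1"/"true"/"yes" (any case) as enabled.
    return os.getenv(name, "").strip().lower() in {"1", "true", "yes"}


FLASH_DECODING = _env_flag("FLASH_DECODING")
FLASH_INFER = _env_flag("FLASH_INFER")

if __name__ == "__main__":
    # Either flag switches common.py to the flattened Seqlen/KV-cache layout.
    print("flattened layout:", FLASH_DECODING or FLASH_INFER)
```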

server/text_generation_server/layers/attention/cuda.py

42 additions, 4 deletions:

```diff
@@ -1,6 +1,10 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    FLASH_DECODING,
+    BLOCK_SIZE,
+    FLASH_INFER,
+)
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
@@ -23,7 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
+    if FLASH_DECODING or FLASH_INFER:
         shape = key_cache.shape
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
@@ -72,7 +76,16 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
+    if FLASH_INFER:
+        from text_generation_server.layers.attention.flash_infer import decode_state
+
+        return decode_state.get().forward(
+            query.contiguous(),
+            paged_kv_cache=(key_cache, value_cache),
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+    elif FLASH_DECODING:
         max_q = 1
         max_k = max_s
         import flash_attn_2_cuda
@@ -206,7 +219,32 @@ def paged_attention(
 
 SUPPORTS_WINDOWING = V2
 
-if V2:
+if FLASH_INFER:
+
+    def attention(
+        q,
+        k,
+        v,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+        causal=True,
+        softcap=0.0,
+    ):
+        from text_generation_server.layers.attention.flash_infer import prefill_state
+
+        return prefill_state.get().forward(
+            q,
+            k,
+            v,
+            causal=causal,
+            window_left=window_size_left,
+            logits_soft_cap=softcap,
+            sm_scale=softmax_scale,
+        )
+
+elif V2:
 
     def attention(
         q,
```
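In both new branches FLASH_INFER takes precedence (over FLASH_DECODING for decode, over V2 for prefill), and the backend is fetched with `decode_state.get()` / `prefill_state.get()` rather than passed as an argument, which reads like per-batch state stored in context variables by the flash_infer module (not shown in this commit). Below is a minimal, self-contained sketch of that dispatch pattern; the class and function names are invented for illustration and the stub backend does no real attention math.

```python
# Illustrative sketch of the "fetch the current backend from a ContextVar"
# pattern implied by decode_state.get(); names are hypothetical and the
# backend is a stub, not the real flashinfer wrapper.
from contextlib import contextmanager
from contextvars import ContextVar


class StubDecodeBackend:
    """Stand-in for a decode-attention wrapper planned for one batch."""

    def forward(self, query, paged_kv_cache, logits_soft_cap, sm_scale):
        # A real backend would run paged decode attention against the KV cache.
        return f"decode(q={query}, softcap={logits_soft_cap}, scale={sm_scale})"


decode_state = ContextVar("decode_state")


@contextmanager
def use_decode_state(backend):
    # Install the backend for the duration of one forward pass, then restore.
    token = decode_state.set(backend)
    try:
        yield backend
    finally:
        decode_state.reset(token)


def paged_attention_stub(query):
    # Mirrors the diff: the kernel entry point pulls whatever backend the
    # caller installed, so no extra arguments thread through the call sites.
    return decode_state.get().forward(
        query,
        paged_kv_cache=None,
        logits_soft_cap=0.0,
        sm_scale=1.0,
    )


if __name__ == "__main__":
    with use_decode_state(StubDecodeBackend()):
        print(paged_attention_stub("q[1, 8, 64]"))
```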
