
Commit eec97b1

support ascend using infer_ext
1 parent 74e3dde commit eec97b1

File tree

12 files changed (+238, -229 lines)

Lines changed: 5 additions & 214 deletions
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
-import torch
-import infer_ext.ops as ext_ops
-from torch import Tensor
 from ..default import multinomial_sampling
+from .apply_rotary_pos_emb import apply_rotary_pos_emb
+from .fill_kv_cache import fill_kv_cache
+from .fused_rotary_emb import fused_rotary_emb
+from .paged_attention_fwd import paged_attention_fwd
+from .rms_norm import rms_norm
 
 __all__ = [
     'rms_norm',
@@ -13,213 +14,3 @@
     'paged_attention_fwd',
     'multinomial_sampling',
 ]
-
-def rms_norm(
-    hidden_states: Tensor,
-    weight: Tensor,
-    epsilon: float = 1e-6
-):
-    return ext_ops.rms_norm(hidden_states, weight, epsilon)
-
-def apply_rotary_pos_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    cos: Tensor,
-    sin: Tensor,
-    position_ids: Tensor,
-    position_ids_1d: Tensor,
-    q_embed=None,
-    k_embed=None,
-    context=None,
-):
-    bs, head, dim = query_states.shape
-    num_kv_heads = key_states.shape[1]
-    query_states_reshaped = query_states.reshape(1, bs, head, dim)
-    key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = cos[position_ids_1d].view(1, bs, 1, -1)
-        sin = sin[position_ids_1d].view(1, bs, 1, -1)
-        if context:
-            setattr(context, 'cos', cos)
-            setattr(context, 'sin', sin)
-    cached_cos = context.cos if context else cos
-    cached_sin = context.sin if context else sin
-    ext_ops.apply_rotary_pos_emb(
-        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
-        None, None, None
-    )
-    if q_embed is None:
-        q_embed = query_states
-    else:
-        q_embed.copy_(query_states)
-    if k_embed is None:
-        k_embed = key_states
-    else:
-        k_embed.copy_(key_states)
-    return q_embed, k_embed
-
-def fused_rotary_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    position_ids: torch.LongTensor,
-    inv_freq: Tensor,
-    scaling_factor: float,
-    out_q: Tensor = None,
-    out_k: Tensor = None,
-    context=None,
-):
-    batch, seqlen, head, dim = query_states.shape
-    num_kv_heads = key_states.shape[-2]
-    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
-    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
-    position_ids = position_ids.squeeze(0).unsqueeze(-1)
-    pos_freq = position_ids / scaling_factor * inv_freq
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1)
-               .repeat(1, 1, 1, 2).to(query_states.dtype))
-        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1)
-               .repeat(1, 1, 1, 2).to(query_states.dtype))
-        if context:
-            setattr(context, 'cos', cos)
-            setattr(context, 'sin', sin)
-    cached_cos = context.cos if context else cos
-    cached_sin = context.sin if context else sin
-    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
-    if out_q is None:
-        out_q = query_states
-    else:
-        out_q.copy_(query_states)
-    if out_k is None:
-        out_k = key_states
-    else:
-        out_k.copy_(key_states)
-    return out_q, out_k
-
-def fill_kv_cache(
-    key_states: Tensor,
-    value_states: Tensor,
-    key_caches: Tensor,
-    value_caches: Tensor,
-    q_start_loc: Tensor,
-    q_seq_length: Tensor,
-    kv_seq_length: Tensor,
-    max_q_seq_length: int,
-    block_offsets: Tensor,
-    context: None,
-):
-    """fill key/value state to cache for paged attention."""
-    ext_ops.fill_kv_cache(key_states, value_states, key_caches,
-                          value_caches, context.kv_start_indices)
-
-def flash_context_attention(
-    query_states: Tensor,
-    key_states: Tensor,
-    value_states: Tensor,
-    attn_output: Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seq_len: Tensor,
-    kv_seq_len: Tensor,
-    block_size: int,
-    kv_cache_len: int,
-    context=None,
-):
-    num_q_heads, dim = query_states.shape[1:3]
-    num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]
-
-    for i in range(batch):
-        if torch.equal(q_seq_len[i], kv_seq_len[i]):
-            ext_ops.context_attention(
-                attn_output,
-                query_states,
-                key_states,
-                value_states,
-                q_start_loc[i:i+1],
-                q_seq_len[i:i+1],
-                num_q_heads,
-                num_kv_heads,
-                context.attention_mask[i:i+1],
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                attn_output,
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i+1],
-                q_seq_len[i:i+1],
-                kv_seq_len[i:i+1],
-                num_q_heads,
-                num_kv_heads,
-                context.attention_mask[i:i+1],
-            )
-
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
-    num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
-    ext_ops.paged_decode_attention(
-        attn_output.view(q.shape),
-        q,
-        k_cache,
-        v_cache,
-        block_offsets,
-        block_size,
-        kv_seq_len,
-        num_q_heads,
-        num_kv_heads,
-    )
-
-def paged_attention_fwd(
-    query_states: Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    attn_output: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seqlens: Tensor,
-    kv_seqlens: Tensor,
-    max_seqlen: int,
-    window_size: int = 1,
-    context=None,
-):
-    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
-    block_num, block_size, head, dim = key_cache.size()
-    kv_cache_len = block_num * block_size
-    k = key_cache.reshape(block_num * block_size, head, dim)
-    v = value_cache.reshape(block_num * block_size, head, dim)
-    if not is_decoding:
-        flash_context_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_output,
-            k,
-            v,
-            block_offsets,
-            q_start_loc,
-            q_seqlens,
-            kv_seqlens,
-            block_size,
-            kv_cache_len,
-            context=context,
-        )
-    else:
-        paged_token_attention(
-            query_states,
-            k,
-            v,
-            attn_output,
-            kv_seqlens,
-            block_offsets,
-            block_size,
-        )
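Aside (not part of the commit): the prefill/decode switch in the removed paged_attention_fwd keys off whether the number of flattened query tokens equals the number of sequences, since a decode step contributes exactly one new token per sequence. A small self-contained illustration of that check, with made-up shapes (3 sequences, 32 heads, head dim 128):

import torch

# Decode step: one new token per sequence, so num_tokens == batch size.
q_seqlens = torch.tensor([1, 1, 1])
query_states = torch.randn(3, 32, 128)  # (num_tokens, heads, head_dim)
assert query_states.shape[-3] == q_seqlens.size(0)  # -> paged_token_attention

# Prefill step: several prompt tokens per sequence, so num_tokens > batch size.
q_seqlens = torch.tensor([5, 7, 3])
query_states = torch.randn(15, 32, 128)
assert query_states.shape[-3] != q_seqlens.size(0)  # -> flash_context_attention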
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import infer_ext.ops as ext_ops
+from torch import Tensor
+
+
+def apply_rotary_pos_emb(
+    query_states: Tensor,
+    key_states: Tensor,
+    cos: Tensor,
+    sin: Tensor,
+    position_ids: Tensor,
+    position_ids_1d: Tensor,
+    q_embed=None,
+    k_embed=None,
+    context=None,
+):
+    bs, head, dim = query_states.shape
+    num_kv_heads = key_states.shape[1]
+    query_states_reshaped = query_states.reshape(1, bs, head, dim)
+    key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
+    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
+        cos = cos[position_ids_1d].view(1, bs, 1, -1)
+        sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        if context:
+            setattr(context, 'cos', cos)
+            setattr(context, 'sin', sin)
+    cached_cos = context.cos if context else cos
+    cached_sin = context.sin if context else sin
+    ext_ops.apply_rotary_pos_emb(
+        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
+        None, None, None
+    )
+    if q_embed is None:
+        q_embed = query_states
+    else:
+        q_embed.copy_(query_states)
+    if k_embed is None:
+        k_embed = key_states
+    else:
+        k_embed.copy_(key_states)
+    return q_embed, k_embed
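Usage sketch (not part of the commit): a minimal call into the wrapper added above, assuming infer_ext is installed, an Ascend device is available, and apply_rotary_pos_emb has been imported from the new module. Shapes follow the reshapes in the wrapper: token-flattened q/k of shape (num_tokens, heads, head_dim) and cos/sin lookup tables indexed by position_ids_1d. SimpleNamespace is only a stand-in assumption for lmdeploy's real step context, which caches cos/sin after the first call:

import torch
from types import SimpleNamespace

num_tokens, heads, kv_heads, dim, max_pos = 8, 32, 8, 128, 2048
q = torch.randn(num_tokens, heads, dim, dtype=torch.float16)
k = torch.randn(num_tokens, kv_heads, dim, dtype=torch.float16)
cos = torch.randn(max_pos, dim, dtype=torch.float16)  # dummy rotary tables
sin = torch.randn(max_pos, dim, dtype=torch.float16)
position_ids_1d = torch.arange(num_tokens)
context = SimpleNamespace()  # first call stores cos/sin on it

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, None,
                                        position_ids_1d, context=context)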
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import infer_ext.ops as ext_ops
+from torch import Tensor
+
+
+def fill_kv_cache(
+    key_states: Tensor,
+    value_states: Tensor,
+    key_caches: Tensor,
+    value_caches: Tensor,
+    q_start_loc: Tensor,
+    q_seq_length: Tensor,
+    kv_seq_length: Tensor,
+    max_q_seq_length: int,
+    block_offsets: Tensor,
+    context: None,
+):
+    """fill key/value state to cache for paged attention."""
+    ext_ops.fill_kv_cache(key_states, value_states, key_caches,
+                          value_caches, context.kv_start_indices)
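Note (not part of the commit): the wrapper above forwards only the key/value tensors, the caches, and context.kv_start_indices to ext_ops.fill_kv_cache; the remaining arguments are kept for interface compatibility and are not used on this backend. A minimal sketch of what a caller supplies, again assuming infer_ext on an Ascend device and that fill_kv_cache has been imported from the new module; SimpleNamespace stands in for lmdeploy's step context, and kv_start_indices is assumed to hold each token's destination slot in the flattened cache:

import torch
from types import SimpleNamespace

num_tokens, kv_heads, dim = 16, 8, 128
num_blocks, block_size = 4, 64
key_states = torch.randn(num_tokens, kv_heads, dim, dtype=torch.float16)
value_states = torch.randn(num_tokens, kv_heads, dim, dtype=torch.float16)
key_caches = torch.zeros(num_blocks, block_size, kv_heads, dim,
                         dtype=torch.float16)
value_caches = torch.zeros_like(key_caches)
context = SimpleNamespace(kv_start_indices=torch.arange(num_tokens))

fill_kv_cache(key_states, value_states, key_caches, value_caches,
              q_start_loc=None, q_seq_length=None, kv_seq_length=None,
              max_q_seq_length=None, block_offsets=None, context=context)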
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import infer_ext.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def fused_rotary_emb(
+    query_states: Tensor,
+    key_states: Tensor,
+    position_ids: torch.LongTensor,
+    inv_freq: Tensor,
+    scaling_factor: float,
+    out_q: Tensor = None,
+    out_k: Tensor = None,
+    context=None,
+):
+    batch, seqlen, head, dim = query_states.shape
+    num_kv_heads = key_states.shape[-2]
+    query_states_reshaped = query_states.view(batch, seqlen, head, dim)
+    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)
+    position_ids = position_ids.squeeze(0).unsqueeze(-1)
+    pos_freq = position_ids / scaling_factor * inv_freq
+    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
+        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1)
+               .repeat(1, 1, 1, 2).to(query_states.dtype))
+        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1)
+               .repeat(1, 1, 1, 2).to(query_states.dtype))
+        if context:
+            setattr(context, 'cos', cos)
+            setattr(context, 'sin', sin)
+    cached_cos = context.cos if context else cos
+    cached_sin = context.sin if context else sin
+    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
+                                 cached_cos, cached_sin, None, None, None)
+    if out_q is None:
+        out_q = query_states
+    else:
+        out_q.copy_(query_states)
+    if out_k is None:
+        out_k = key_states
+    else:
+        out_k.copy_(key_states)
+    return out_q, out_k
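Usage sketch (not part of the commit): a minimal call into fused_rotary_emb, assuming infer_ext on an Ascend device and that the function has been imported from the new module. Shapes follow the view calls above: q/k of shape (batch, seqlen, heads, head_dim). Passing context=None makes the wrapper recompute cos/sin from inv_freq on every call; the inv_freq line is just the usual RoPE inverse-frequency formula and is not defined anywhere in this diff:

import torch

batch, seqlen, heads, kv_heads, dim = 1, 32, 32, 8, 128
q = torch.randn(batch, seqlen, heads, dim, dtype=torch.float16)
k = torch.randn(batch, seqlen, kv_heads, dim, dtype=torch.float16)
position_ids = torch.arange(seqlen)[None]  # (1, seqlen)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

out_q, out_k = fused_rotary_emb(q, k, position_ids, inv_freq,
                                scaling_factor=1.0, context=None)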
