add kv_cache to LLM #244
Merged · 12 commits · Mar 21, 2023
examples/llm/src/models/layers/attention.py (70 changes: 59 additions & 11 deletions)
@@ -23,12 +23,19 @@ def scaled_multihead_dot_product_attention(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
training=False,
needs_weights=False,
):
if query_padding_mask is not None:
query = query.masked_fill(~query_padding_mask.unsqueeze(-1), 0)
if key_padding_mask is not None:
key = key.masked_fill(~key_padding_mask.unsqueeze(-1), 0)
value = value.masked_fill(~key_padding_mask.unsqueeze(-1), 0)

q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads) # includes key.t()
v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
@@ -50,6 +57,9 @@ def scaled_multihead_dot_product_attention(
)
attn_weight = attn_weight + attn_bias

if query_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~query_padding_mask.view(b, 1, s_q, 1), -float('inf'))
if key_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~key_padding_mask.view((b, 1, 1, s_k)), -float('inf'))
@@ -65,14 +75,21 @@ def scaled_multihead_dot_product_attention(

attn_weight = torch.softmax(attn_weight, dim=-1)

if query_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~query_padding_mask.view(b, 1, s_q, 1), 0)
if key_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~key_padding_mask.view(b, 1, 1, s_k), 0)

if dropout_p:
attn_weight = torch.nn.functional.dropout(attn_weight,
p=dropout_p,
training=training,
inplace=True)

out = attn_weight.matmul(v)
-    out = rearrange(out, 'b h s d -> b s (h d)', h=n_heads)
+    out = rearrange(out, 'b h s d -> b s (h d)')

if needs_weights:
return out, attn_weight
@@ -94,6 +111,7 @@ def flash_attn_fn(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
@@ -113,15 +131,13 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if query_padding_mask is None:
query_padding_mask = torch.ones_like(query[:, :, 0], dtype=torch.bool)
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)

-    if training:
-        pad_mask = key_padding_mask
-    else:
-        pad_mask = torch.ones_like(query[:, :, 0], dtype=torch.bool)
    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
-        query, pad_mask)
+        query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
@@ -158,6 +174,7 @@ def triton_flash_attn_fn(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
@@ -179,16 +196,31 @@ def triton_flash_attn_fn(
raise NotImplementedError(
f'attn_impl: triton cannot return attn weights.')

if query_padding_mask is not None:
query = query.masked_fill(~query_padding_mask.unsqueeze(-1), 0)
if key_padding_mask is not None:
-        b_size, s_k = key_padding_mask.shape
key = key.masked_fill(~key_padding_mask.unsqueeze(-1), 0)
value = value.masked_fill(~key_padding_mask.unsqueeze(-1), 0)

if query_padding_mask is not None or key_padding_mask is not None:
b_size, s_q, s_k = query.size(0), 1, 1
if query_padding_mask is not None:
s_q = query_padding_mask.size(1)
if key_padding_mask is not None:
s_k = key_padding_mask.size(1)

if attn_bias is not None:
attn_bias = attn_bias.expand(b_size, -1, -1, -1)
else:
-            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
+            attn_bias = query.new_zeros(b_size, 1, s_q, s_k)

-        attn_bias = attn_bias.masked_fill(
-            ~key_padding_mask.view((b_size, 1, 1, s_k)), -float('inf'))
if query_padding_mask is not None:
attn_bias = attn_bias.masked_fill(
~query_padding_mask.view((b_size, 1, s_q, 1)), -float('inf'))

if key_padding_mask is not None:
attn_bias = attn_bias.masked_fill(
~key_padding_mask.view((b_size, 1, 1, s_k)), -float('inf'))

query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
@@ -199,6 +231,9 @@ def triton_flash_attn_fn(
softmax_scale)

output = attn_output.view(*attn_output.shape[:2], -1)
if query_padding_mask is not None:
output = output.masked_fill(~query_padding_mask.unsqueeze(-1), 0)

return output, None
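
Not part of the diff — a minimal sketch of the padding-mask convention these changes follow (boolean masks with True marking real tokens): padded key positions are excluded before the softmax so they receive zero attention, and rows belonging to padded query positions are zeroed afterwards. Names and shapes are illustrative, not taken from the repo.

```python
import torch

def masked_softmax(attn_weight, query_padding_mask, key_padding_mask):
    """attn_weight: (b, h, s_q, s_k); masks: (b, s_q) / (b, s_k), True = keep."""
    b, _, s_q, s_k = attn_weight.shape
    if key_padding_mask is not None:
        # padded keys can never be attended to
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view(b, 1, 1, s_k), float('-inf'))
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if query_padding_mask is not None:
        # attention rows for padded queries carry no information; zero them
        attn_weight = attn_weight.masked_fill(
            ~query_padding_mask.view(b, 1, s_q, 1), 0.0)
    return attn_weight

# toy example: the last position of the sequence is padding
w = torch.randn(1, 1, 3, 3)
mask = torch.tensor([[True, True, False]])
print(masked_softmax(w, mask, mask))
```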


@@ -256,6 +291,7 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):

def forward(self,
x,
past_key_value=None,
attn_bias=None,
key_padding_mask=None,
is_causal=True,
@@ -266,12 +302,23 @@ def forward(self,

query, key, value = qkv.chunk(3, dim=2)

query_padding_mask = None
if key_padding_mask is not None:
query_padding_mask = key_padding_mask[:, -query.size(1):]

if self.attn_qk_ln:
# Applying layernorm to qk
dtype = query.dtype
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)

if past_key_value is not None:
            if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
value = torch.cat([past_key_value[1], value], dim=1)

past_key_value = (key, value)

if attn_bias is not None:
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]

@@ -282,14 +329,15 @@ def forward(self,
self.n_heads,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
query_padding_mask=query_padding_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
)

-        return self.out_proj(context), attn_weights
+        return self.out_proj(context), attn_weights, past_key_value


def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, causal):
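Not part of the diff — a minimal sketch of the cache handling that the attention forward above introduces: cached keys/values are prepended along the sequence dimension and the refreshed (key, value) pair is handed back to the caller. Shapes follow the (batch, seq, d_model) layout of the torch/triton paths; names are illustrative.

```python
import torch

def update_kv_cache(key, value, past_key_value=None):
    # prepend cached keys/values along the sequence dim, mirroring the logic
    # added to MultiheadAttention.forward, then return the refreshed cache
    if past_key_value is not None and len(past_key_value) != 0:
        key = torch.cat([past_key_value[0], key], dim=1)
        value = torch.cat([past_key_value[1], value], dim=1)
    return key, value, (key, value)

# incremental decoding: each step contributes one new position
cache = ()
for _ in range(3):
    k_step = torch.randn(2, 1, 64)  # (batch, 1, d_model)
    v_step = torch.randn(2, 1, 64)
    k, v, cache = update_kv_cache(k_step, v_step, cache)
print(k.shape)  # torch.Size([2, 3, 64])
```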
examples/llm/src/models/layers/gpt_blocks.py (16 changes: 9 additions & 7 deletions)
@@ -3,7 +3,7 @@

"""GPT Blocks used for the GPT Model."""

-from typing import Optional
+from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -48,17 +48,19 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
def forward(
self,
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
a = self.ln_1(x)
-        b, _ = self.attn(a,
-                         attn_bias=attn_bias,
-                         key_padding_mask=key_padding_mask,
-                         is_causal=is_causal)
+        b, _, past_key_value = self.attn(a,
+                                         past_key_value=past_key_value,
+                                         attn_bias=attn_bias,
+                                         key_padding_mask=key_padding_mask,
+                                         is_causal=is_causal)
x = x + self.resid_attn_dropout(b)
m = self.ln_2(x)
n = self.mlp(m)
x = x + self.resid_mlp_dropout(n)
-        return x
+        return x, past_key_value
examples/llm/src/models/mosaic_gpt.py (47 changes: 37 additions & 10 deletions)
@@ -8,7 +8,7 @@

import math
import warnings
-from typing import Optional
+from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -131,9 +131,11 @@ def _attn_bias(self, device, dtype):

return self.attn_bias

-    def forward(self,
-                input_ids: torch.LongTensor,
-                key_padding_mask: Optional[torch.ByteTensor] = None):
+    def forward(
+            self,
+            input_ids: torch.LongTensor,
+            key_padding_mask: Optional[torch.ByteTensor] = None,
+            past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None):
S = input_ids.size(1)
assert (
S <= self.cfg.max_seq_len
@@ -143,7 +145,25 @@ def forward(self,
if self.alibi:
x = tok_emb
else:
-            pos = torch.arange(0, S, dtype=torch.long,
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.cfg.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention ' +\
f'layer in the network ({len(past_key_values)=}; {self.cfg.n_layers=}).'
)
                # grab the key tensor, whose spec should be (batch, seq, dim),
                # and use its `seq` length to shift the position embedding
past_position = past_key_values[0][0].size(1)

if S + past_position > self.cfg.max_seq_len:
raise ValueError(
f'Cannot forward input with past sequence length {past_position} and current sequence length '
f'{S + 1}, this model only supports total sequence length <= {self.cfg.max_seq_len}.'
)
pos = torch.arange(past_position,
S + past_position,
dtype=torch.long,
device=input_ids.device).unsqueeze(0)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = tok_emb + pos_emb
@@ -158,11 +178,18 @@ def forward(self,
x = self.transformer.emb_drop(x_shrunk)

attn_bias = self._attn_bias(device=x.device, dtype=x.dtype)
-        for block in self.transformer.blocks:  # type: ignore
-            x = block(x,
-                      attn_bias=attn_bias,
-                      key_padding_mask=key_padding_mask,
-                      is_causal=self.is_causal)

for b_idx, block in enumerate(self.transformer.blocks): # type: ignore
past_key_value = past_key_values[
b_idx] if past_key_values is not None else None
x, past_key_value = block(x,
past_key_value=past_key_value,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=self.is_causal)
if past_key_values is not None:
past_key_values[b_idx] = past_key_value

x = self.transformer.ln_f(x) # type: ignore
# output embedding weight tied to input embedding
assert isinstance(self.transformer.wte, nn.Module) # pyright
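Not part of the diff — one way a caller might drive the new past_key_values argument for greedy decoding. It assumes forward still returns logits of shape (batch, seq, vocab), and it seeds each layer's cache with zero-length key/value tensors so that past_key_values[0][0].size(1) evaluates to 0 on the first call; n_layers and d_model would come from the model config.

```python
import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens, n_layers, d_model):
    b, device = input_ids.size(0), input_ids.device
    # zero-length (key, value) pair per layer; the forward pass overwrites each
    # entry in place with the accumulated cache
    past_key_values = [(torch.zeros(b, 0, d_model, device=device),
                        torch.zeros(b, 0, d_model, device=device))
                       for _ in range(n_layers)]
    logits = model(input_ids, past_key_values=past_key_values)
    for _ in range(max_new_tokens):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # only the newest token is fed in; cached keys/values cover the prefix
        logits = model(next_token, past_key_values=past_key_values)
    return input_ids
```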
examples/llm/tests/test_flash_triton_torch.py (35 changes: 19 additions & 16 deletions)
@@ -54,7 +54,7 @@ def test_attn_impl(attn_impl_0,

key_padding_mask = torch.ones(n, s).to(device).bool()

-    def gen_bias(attn_impl, key_padding_mask):
+    def gen_bias(attn_impl):
causal = True
attn_bias = None
bs = attention.attn_bias_shape(attn_impl,
@@ -72,24 +72,26 @@ def gen_bias(attn_impl, key_padding_mask):
alibi=alibi,
alibi_bias_max=8)

-        return attn_bias, key_padding_mask
+        return attn_bias

x0 = torch.randn(n, s, f).to(device)
x1 = x0.clone().detach()
x0.requires_grad = True
x1.requires_grad = True

with torch.autocast(x0.device.type):
-        attn_bias, kpm = gen_bias(attn0.attn_impl, key_padding_mask)
-        y0, _ = attn0(x0,
-                      attn_bias=attn_bias,
-                      key_padding_mask=kpm,
-                      is_causal=True)
-        attn_bias, kpm = gen_bias(attn1.attn_impl, key_padding_mask)
-        y1, _ = attn1(x1,
-                      attn_bias=attn_bias,
-                      key_padding_mask=kpm,
-                      is_causal=True)
+        attn_bias = gen_bias(attn0.attn_impl)
+        y0, _, _ = attn0(x0,
+                         past_key_value=None,
+                         attn_bias=attn_bias,
+                         key_padding_mask=key_padding_mask,
+                         is_causal=True)
+        attn_bias = gen_bias(attn1.attn_impl)
+        y1, _, _ = attn1(x1,
+                         past_key_value=None,
+                         attn_bias=attn_bias,
+                         key_padding_mask=key_padding_mask,
+                         is_causal=True)
y0 *= key_padding_mask.unsqueeze(-1)
y1 *= key_padding_mask.unsqueeze(-1)

@@ -160,10 +162,11 @@ def gen_tca_mask():
x1.requires_grad = True

with torch.autocast(x0.device.type):
-        y0, _ = mmhsa(x0,
-                      attn_bias=None,
-                      key_padding_mask=key_padding_mask,
-                      is_causal=True)
+        y0, _, _ = mmhsa(x0,
+                         past_key_value=None,
+                         attn_bias=None,
+                         key_padding_mask=key_padding_mask,
+                         is_causal=True)
y1, _ = tmhsa(x1,
x1,
x1,
examples/llm/tests/test_model.py (11 changes: 6 additions & 5 deletions)
@@ -161,11 +162,12 @@ def test_attention_mechanism(batch_size=2):

for block in model.model.transformer.blocks:
a = block.ln_1(x)
-        b, attention_weights = block.attn(a,
-                                          attn_bias=attn_bias,
-                                          key_padding_mask=key_padding_mask,
-                                          is_causal=model.model.is_causal,
-                                          needs_weights=True)
+        b, attention_weights, _ = block.attn(a,
+                                             past_key_value=None,
+                                             attn_bias=attn_bias,
+                                             key_padding_mask=key_padding_mask,
+                                             is_causal=model.model.is_causal,
+                                             needs_weights=True)

zerod_weights = (attention_weights == 0)
assert torch.equal(expected_zerod_weights.expand(*zerod_weights.shape),
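
Not added in this PR, but the new cache path suggests an equivalence check: in eval mode, logits from one full forward pass should closely match logits gathered token-by-token through past_key_values (flash/triton kernels may need a looser tolerance). A sketch, with model, n_layers, and d_model assumed to come from the test fixture:

```python
import torch

def check_kv_cache_consistency(model, input_ids, n_layers, d_model, atol=1e-4):
    model.eval()
    with torch.no_grad():
        full_logits = model(input_ids)

        b, device = input_ids.size(0), input_ids.device
        past_key_values = [(torch.zeros(b, 0, d_model, device=device),
                            torch.zeros(b, 0, d_model, device=device))
                           for _ in range(n_layers)]
        # feed one token at a time, reusing the cache the model updates in place
        step_logits = [
            model(input_ids[:, t:t + 1], past_key_values=past_key_values)
            for t in range(input_ids.size(1))
        ]

    assert torch.allclose(full_logits, torch.cat(step_logits, dim=1), atol=atol)
```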