add kv_cache to LLM (#244)
* init kv_cache pr

* pass past_key_values thru model fwd

* add qpadmask

* enable key_padding_mask for past_key_value

* use valid pos emb if using learned pos emb

* lint

* only compute past_position when needed

* nit

* Update examples/llm/src/models/layers/gpt_blocks.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* fix suggested edit

* dk pr cmt

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
vchiley and dakinggg authored Mar 21, 2023
1 parent 2b09481 commit 83e6998
Showing 5 changed files with 130 additions and 49 deletions.
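Taken together, the diff threads a per-block key/value cache through the stack: MosaicGPT.forward now accepts past_key_values (one entry per transformer block) and updates the supplied list in place, each attention module concatenates cached keys/values ahead of the new ones, and learned positional embeddings are offset by the cached length. A minimal sketch of the intended call pattern for incremental greedy decoding follows; the greedy_decode helper is hypothetical, and it assumes the inner MosaicGPT module returns logits for the current tokens and is configured with alibi (with learned positional embeddings, the first call reads past_key_values[0][0], so the cache entries could not start out empty).

```python
import torch

@torch.no_grad()
def greedy_decode(mosaic_gpt, input_ids, max_new_tokens):
    # One empty tuple per block; MosaicGPT.forward fills and updates the list in place.
    past_key_values = [() for _ in range(mosaic_gpt.cfg.n_layers)]
    generated = input_ids
    # Prime the cache on the full prompt, then feed one new token at a time.
    logits = mosaic_gpt(input_ids, past_key_values=past_key_values)
    for _ in range(max_new_tokens):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        logits = mosaic_gpt(next_token, past_key_values=past_key_values)
    return generated
```

Each subsequent step feeds only the newly sampled token, so per-step attention cost scales with the cached length instead of re-encoding the full prefix.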
70 changes: 59 additions & 11 deletions examples/llm/src/models/layers/attention.py
@@ -23,12 +23,19 @@ def scaled_multihead_dot_product_attention(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
training=False,
needs_weights=False,
):
if query_padding_mask is not None:
query = query.masked_fill(~query_padding_mask.unsqueeze(-1), 0)
if key_padding_mask is not None:
key = key.masked_fill(~key_padding_mask.unsqueeze(-1), 0)
value = value.masked_fill(~key_padding_mask.unsqueeze(-1), 0)

q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads) # includes key.t()
v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
@@ -50,6 +57,9 @@ def scaled_multihead_dot_product_attention(
)
attn_weight = attn_weight + attn_bias

if query_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~query_padding_mask.view(b, 1, s_q, 1), -float('inf'))
if key_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~key_padding_mask.view((b, 1, 1, s_k)), -float('inf'))
@@ -65,14 +75,21 @@ def scaled_multihead_dot_product_attention(

attn_weight = torch.softmax(attn_weight, dim=-1)

if query_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~query_padding_mask.view(b, 1, s_q, 1), 0)
if key_padding_mask is not None:
attn_weight = attn_weight.masked_fill(
~key_padding_mask.view(b, 1, 1, s_k), 0)

if dropout_p:
attn_weight = torch.nn.functional.dropout(attn_weight,
p=dropout_p,
training=training,
inplace=True)

out = attn_weight.matmul(v)
out = rearrange(out, 'b h s d -> b s (h d)', h=n_heads)
out = rearrange(out, 'b h s d -> b s (h d)')

if needs_weights:
return out, attn_weight
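For reference, a self-contained sketch of the two-sided masking the torch implementation now applies around the softmax (toy shapes, not the repo's API): padded keys are excluded before the softmax, while padded query rows, which become all -inf and therefore NaN after the softmax, are re-zeroed afterwards.

```python
import torch

b, h, s_q, s_k = 2, 4, 3, 5
attn_weight = torch.randn(b, h, s_q, s_k)
query_padding_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool)            # (b, s_q)
key_padding_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)  # (b, s_k)

# Before the softmax: padded keys and padded query rows are pushed to -inf.
attn_weight = attn_weight.masked_fill(~query_padding_mask.view(b, 1, s_q, 1), -float('inf'))
attn_weight = attn_weight.masked_fill(~key_padding_mask.view(b, 1, 1, s_k), -float('inf'))
attn_weight = torch.softmax(attn_weight, dim=-1)

# After the softmax: rows belonging to padded queries were all -inf and come out
# as NaN, so they are reset to exactly zero.
attn_weight = attn_weight.masked_fill(~query_padding_mask.view(b, 1, s_q, 1), 0)
assert not torch.isnan(attn_weight).any()
```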
@@ -94,6 +111,7 @@ def flash_attn_fn(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
@@ -113,15 +131,13 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

if query_padding_mask is None:
query_padding_mask = torch.ones_like(query[:, :, 0], dtype=torch.bool)
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)

if training:
pad_mask = key_padding_mask
else:
pad_mask = torch.ones_like(query[:, :, 0], dtype=torch.bool)
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
query, pad_mask)
query, query_padding_mask)
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
@@ -158,6 +174,7 @@ def triton_flash_attn_fn(
n_heads,
softmax_scale=None,
attn_bias=None,
query_padding_mask=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
@@ -179,16 +196,31 @@ def triton_flash_attn_fn(
raise NotImplementedError(
f'attn_impl: triton cannot return attn weights.')

if query_padding_mask is not None:
query = query.masked_fill(~query_padding_mask.unsqueeze(-1), 0)
if key_padding_mask is not None:
b_size, s_k = key_padding_mask.shape
key = key.masked_fill(~key_padding_mask.unsqueeze(-1), 0)
value = value.masked_fill(~key_padding_mask.unsqueeze(-1), 0)

if query_padding_mask is not None or key_padding_mask is not None:
b_size, s_q, s_k = query.size(0), 1, 1
if query_padding_mask is not None:
s_q = query_padding_mask.size(1)
if key_padding_mask is not None:
s_k = key_padding_mask.size(1)

if attn_bias is not None:
attn_bias = attn_bias.expand(b_size, -1, -1, -1)
else:
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
attn_bias = query.new_zeros(b_size, 1, s_q, s_k)

attn_bias = attn_bias.masked_fill(
~key_padding_mask.view((b_size, 1, 1, s_k)), -float('inf'))
if query_padding_mask is not None:
attn_bias = attn_bias.masked_fill(
~query_padding_mask.view((b_size, 1, s_q, 1)), -float('inf'))

if key_padding_mask is not None:
attn_bias = attn_bias.masked_fill(
~key_padding_mask.view((b_size, 1, 1, s_k)), -float('inf'))

query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
@@ -199,6 +231,9 @@ def triton_flash_attn_fn(
softmax_scale)

output = attn_output.view(*attn_output.shape[:2], -1)
if query_padding_mask is not None:
output = output.masked_fill(~query_padding_mask.unsqueeze(-1), 0)

return output, None
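Both kernel wrappers now consume the explicit query_padding_mask: the flash path unpads queries with it in eval as well as training, and the triton path folds both masks into an additive bias and zeroes padded output rows. As a rough, self-contained stand-in for what the unpadding step produces (illustration only, not the flash-attn bert_padding API):

```python
import torch
import torch.nn.functional as F

def unpad_input_sketch(x, padding_mask):
    """Pack only the valid tokens of a padded (b, s, d) batch into (nnz, d)."""
    b, s, d = x.shape
    seqlens = padding_mask.sum(dim=1, dtype=torch.int32)              # valid tokens per row
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    x_unpad = x.reshape(b * s, d)[indices]                            # (nnz, d)
    return x_unpad, indices, cu_seqlens, int(seqlens.max())

x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input_sketch(x, mask)
print(x_unpad.shape, cu_seqlens)  # torch.Size([5, 8]) tensor([0, 3, 5], dtype=torch.int32)
```

The returned indices are what allow the wrapper to re-pad the kernel's packed output back to (b, s, d) afterwards.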


@@ -256,6 +291,7 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):

def forward(self,
x,
past_key_value=None,
attn_bias=None,
key_padding_mask=None,
is_causal=True,
@@ -266,12 +302,23 @@ def forward(self,

query, key, value = qkv.chunk(3, dim=2)

query_padding_mask = None
if key_padding_mask is not None:
query_padding_mask = key_padding_mask[:, -query.size(1):]

if self.attn_qk_ln:
# Applying layernorm to qk
dtype = query.dtype
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)

if past_key_value is not None:
if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
value = torch.cat([past_key_value[1], value], dim=1)

past_key_value = (key, value)

if attn_bias is not None:
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]

@@ -282,14 +329,15 @@ def forward(self,
self.n_heads,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
query_padding_mask=query_padding_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
)

return self.out_proj(context), attn_weights
return self.out_proj(context), attn_weights, past_key_value


def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, causal):
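The core of the cache handling in MultiheadAttention.forward is the concatenation of cached keys/values ahead of the new ones (guarded by len(past_key_value) != 0) and the derivation of query_padding_mask as the tail of key_padding_mask. A toy, dimension-only sketch; names and sizes are illustrative, not the module itself:

```python
import torch

b, d = 2, 8
past_key_value = (torch.randn(b, 6, d), torch.randn(b, 6, d))   # 6 cached positions
query = torch.randn(b, 1, d)                                    # 1 new position
key = torch.randn(b, 1, d)
value = torch.randn(b, 1, d)
key_padding_mask = torch.ones(b, 7, dtype=torch.bool)           # covers past + current

if len(past_key_value) != 0:
    key = torch.cat([past_key_value[0], key], dim=1)            # (b, 7, d)
    value = torch.cat([past_key_value[1], value], dim=1)        # (b, 7, d)
past_key_value = (key, value)                                   # updated cache handed back to the caller

# The new queries only cover the current positions, so their padding mask is
# the tail of the full key mask.
query_padding_mask = key_padding_mask[:, -query.size(1):]       # (b, 1)
```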
16 changes: 9 additions & 7 deletions examples/llm/src/models/layers/gpt_blocks.py
@@ -3,7 +3,7 @@

"""GPT Blocks used for the GPT Model."""

from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -48,17 +48,19 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
def forward(
self,
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
a = self.ln_1(x)
b, _ = self.attn(a,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal)
b, _, past_key_value = self.attn(a,
past_key_value=past_key_value,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal)
x = x + self.resid_attn_dropout(b)
m = self.ln_2(x)
n = self.mlp(m)
x = x + self.resid_mlp_dropout(n)
return x
return x, past_key_value
47 changes: 37 additions & 10 deletions examples/llm/src/models/mosaic_gpt.py
@@ -8,7 +8,7 @@

import math
import warnings
from typing import Optional
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -131,9 +131,11 @@ def _attn_bias(self, device, dtype):

return self.attn_bias

def forward(self,
input_ids: torch.LongTensor,
key_padding_mask: Optional[torch.ByteTensor] = None):
def forward(
self,
input_ids: torch.LongTensor,
key_padding_mask: Optional[torch.ByteTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None):
S = input_ids.size(1)
assert (
S <= self.cfg.max_seq_len
@@ -143,7 +145,25 @@ def forward(self,
if self.alibi:
x = tok_emb
else:
pos = torch.arange(0, S, dtype=torch.long,
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.cfg.n_layers:
raise ValueError(
f'past_key_values must provide a past_key_value for each attention ' +\
f'layer in the network ({len(past_key_values)=}; {self.cfg.n_layers=}).'
)
# get the key tensor whose spec should be (batch, seq, dim), and
# collect the `seq`, so that the position embedding is shifted
past_position = past_key_values[0][0].size(1)

if S + past_position > self.cfg.max_seq_len:
raise ValueError(
f'Cannot forward input with past sequence length {past_position} and current sequence length '
f'{S + 1}, this model only supports total sequence length <= {self.cfg.max_seq_len}.'
)
pos = torch.arange(past_position,
S + past_position,
dtype=torch.long,
device=input_ids.device).unsqueeze(0)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = tok_emb + pos_emb
@@ -158,11 +178,18 @@ def forward(self,
x = self.transformer.emb_drop(x_shrunk)

attn_bias = self._attn_bias(device=x.device, dtype=x.dtype)
for block in self.transformer.blocks: # type: ignore
x = block(x,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=self.is_causal)

for b_idx, block in enumerate(self.transformer.blocks): # type: ignore
past_key_value = past_key_values[
b_idx] if past_key_values is not None else None
x, past_key_value = block(x,
past_key_value=past_key_value,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=self.is_causal)
if past_key_values is not None:
past_key_values[b_idx] = past_key_value

x = self.transformer.ln_f(x) # type: ignore
# output embedding weight tied to input embedding
assert isinstance(self.transformer.wte, nn.Module) # pyright
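When learned positional embeddings are in use, the position ids for the new tokens start at the cached length rather than at zero; otherwise every incremental step would reuse position 0. A minimal sketch with a stand-in embedding table (toy sizes):

```python
import torch

max_seq_len, d_model = 16, 8
wpe = torch.nn.Embedding(max_seq_len, d_model)   # stand-in for transformer.wpe

past_position = 5   # past_key_values[0][0].size(1): sequence length already cached
S = 1               # number of new tokens in this forward pass

pos = torch.arange(past_position, S + past_position, dtype=torch.long).unsqueeze(0)
pos_emb = wpe(pos)  # embedding for position 5, not position 0
```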
35 changes: 19 additions & 16 deletions examples/llm/tests/test_flash_triton_torch.py
@@ -54,7 +54,7 @@ def test_attn_impl(attn_impl_0,

key_padding_mask = torch.ones(n, s).to(device).bool()

def gen_bias(attn_impl, key_padding_mask):
def gen_bias(attn_impl):
causal = True
attn_bias = None
bs = attention.attn_bias_shape(attn_impl,
@@ -72,24 +72,26 @@ def gen_bias(attn_impl, key_padding_mask):
alibi=alibi,
alibi_bias_max=8)

return attn_bias, key_padding_mask
return attn_bias

x0 = torch.randn(n, s, f).to(device)
x1 = x0.clone().detach()
x0.requires_grad = True
x1.requires_grad = True

with torch.autocast(x0.device.type):
attn_bias, kpm = gen_bias(attn0.attn_impl, key_padding_mask)
y0, _ = attn0(x0,
attn_bias=attn_bias,
key_padding_mask=kpm,
is_causal=True)
attn_bias, kpm = gen_bias(attn1.attn_impl, key_padding_mask)
y1, _ = attn1(x1,
attn_bias=attn_bias,
key_padding_mask=kpm,
is_causal=True)
attn_bias = gen_bias(attn0.attn_impl)
y0, _, _ = attn0(x0,
past_key_value=None,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=True)
attn_bias = gen_bias(attn1.attn_impl)
y1, _, _ = attn1(x1,
past_key_value=None,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=True)
y0 *= key_padding_mask.unsqueeze(-1)
y1 *= key_padding_mask.unsqueeze(-1)

@@ -160,10 +162,11 @@ def gen_tca_mask():
x1.requires_grad = True

with torch.autocast(x0.device.type):
y0, _ = mmhsa(x0,
attn_bias=None,
key_padding_mask=key_padding_mask,
is_causal=True)
y0, _, _ = mmhsa(x0,
past_key_value=None,
attn_bias=None,
key_padding_mask=key_padding_mask,
is_causal=True)
y1, _ = tmhsa(x1,
x1,
x1,
11 changes: 6 additions & 5 deletions examples/llm/tests/test_model.py
@@ -161,11 +161,12 @@ def test_attention_mechanism(batch_size=2):

for block in model.model.transformer.blocks:
a = block.ln_1(x)
b, attention_weights = block.attn(a,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=model.model.is_causal,
needs_weights=True)
b, attention_weights, _ = block.attn(a,
past_key_value=None,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=model.model.is_causal,
needs_weights=True)

zerod_weights = (attention_weights == 0)
assert torch.equal(expected_zerod_weights.expand(*zerod_weights.shape),
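The updated tests only exercise the new call signatures; a natural follow-up (not part of this diff) is an equivalence check that a cached, two-step forward matches a single full forward on the last token. A hedged sketch, assuming a MosaicGPT-style module whose forward returns logits and which updates the supplied past_key_values list in place:

```python
import torch

def check_kv_cache_equivalence(model, input_ids, atol=1e-4):
    model.eval()
    with torch.no_grad():
        full_logits = model(input_ids)                               # no cache
        # Assumes an alibi-configured model so the empty initial cache is accepted.
        past_key_values = [() for _ in range(model.cfg.n_layers)]
        model(input_ids[:, :-1], past_key_values=past_key_values)    # prime the cache
        step_logits = model(input_ids[:, -1:], past_key_values=past_key_values)
    assert torch.allclose(full_logits[:, -1], step_logits[:, -1], atol=atol)
```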
