[transformer] support multi query attention && grouped query attention #2403

Merged · 4 commits · Mar 11, 2024
110 changes: 91 additions & 19 deletions wenet/transformer/attention.py
@@ -16,7 +16,7 @@
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import nn
@@ -26,6 +26,14 @@

class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
If n_kv_head is not None and n_kv_head != n_head, multi-query or
grouped-query attention is used; see:
https://arxiv.org/pdf/1911.02150.pdf
https://arxiv.org/pdf/2305.13245.pdf

Example:
case 1: n_kv_head == None, head_dim == None, MultiHead attention (MHSA)
case 2: n_kv_head=1, n_head = 16, MultiQuery attention (MQA)
case 3: n_kv_head=2, n_head = 16, GroupedQuery attention (GQA)

Args:
n_head (int): The number of heads.
@@ -41,17 +49,30 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0

self.inner_dim = n_feat if head_dim is None else head_dim * n_head
if n_kv_head is not None:
assert head_dim is not None
self.inner_kv_dim = head_dim * n_kv_head
n_kv_head = n_kv_head
else:
self.inner_kv_dim = self.inner_dim
n_kv_head = n_head
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
assert self.d_k == self.inner_kv_dim // n_kv_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat, bias=query_bias)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat, bias=value_bias)
self.linear_out = nn.Linear(n_feat, n_feat)
self.h_kv = n_kv_head

self.linear_q = nn.Linear(n_feat, self.inner_dim, bias=query_bias)
self.linear_k = nn.Linear(n_feat, self.inner_kv_dim, bias=key_bias)
self.linear_v = nn.Linear(n_feat, self.inner_kv_dim, bias=value_bias)
self.linear_out = nn.Linear(self.inner_dim, n_feat, bias=query_bias)
self.dropout = nn.Dropout(p=dropout_rate)

self.use_sdpa = use_sdpa
@@ -61,16 +82,21 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:
assert x.ndim >= 3
if name == 'query':
x = self.linear_q(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
elif name == 'key':
x = self.linear_k(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h_kv, self.d_k])
else:
assert name == 'value'
x = self.linear_v(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h_kv, self.d_k])

# split last dim
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
x = x.view(x_shape)
x = x.transpose(-3, -2) # (batch, ..., head, time, d_k)
x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k)
return x

def forward_qkv(
@@ -87,9 +113,9 @@ def forward_qkv(
torch.Tensor: Transformed query tensor, size
(#batch, ..., n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, ..., n_head, time2, d_k).
(#batch, ..., n_head_kv, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, ..., n_head, time2, d_k).
(#batch, ..., n_head_kv, time2, d_k).

"""
q = self._forward_linearx('query', query)
@@ -210,6 +236,19 @@ def forward(
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
@@ -244,10 +283,12 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
value_bias, use_sdpa, n_kv_head, head_dim)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
@@ -335,10 +376,24 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)

# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
@@ -395,9 +450,11 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
value_bias, use_sdpa, n_kv_head, head_dim)

def forward(
self,
@@ -418,6 +475,19 @@ def forward(
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

B = query.size(0)
Beams = 1
if B != k.size(0):
@@ -464,10 +534,12 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):

use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
del n_kv_head, head_dim
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
value_bias, use_sdpa, None, None)
# TODO(Mddct): 64 8 1 as args
self.max_right_rel_pos = 64
self.max_left_rel_pos = 8
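To make the shape bookkeeping above concrete, here is a minimal standalone PyTorch sketch of the grouped-query path (illustrative sizes and plain `nn.Linear` layers, not the wenet classes): queries are projected with `h` heads, keys/values with `h_kv` heads, and the key/value heads are expanded with `torch.repeat_interleave` before the scaled dot-product, mirroring the `forward()` changes in this file.

```python
import math

import torch
from torch import nn

# Illustrative sizes (not from this PR): 8 query heads share 2 kv heads (GQA).
B, T, n_feat, h, h_kv = 2, 10, 256, 8, 2
d_k = n_feat // h

linear_q = nn.Linear(n_feat, h * d_k)       # query projection: h heads
linear_k = nn.Linear(n_feat, h_kv * d_k)    # key projection:   h_kv heads
linear_v = nn.Linear(n_feat, h_kv * d_k)    # value projection: h_kv heads

x = torch.randn(B, T, n_feat)
q = linear_q(x).view(B, T, h, d_k).transpose(1, 2)     # (B, h,    T, d_k)
k = linear_k(x).view(B, T, h_kv, d_k).transpose(1, 2)  # (B, h_kv, T, d_k)
v = linear_v(x).view(B, T, h_kv, d_k).transpose(1, 2)  # (B, h_kv, T, d_k)

# As in the diff: repeat each kv head so every query head has a matching one.
k = torch.repeat_interleave(k, h // h_kv, dim=-3)      # (B, h, T, d_k)
v = torch.repeat_interleave(v, h // h_kv, dim=-3)      # (B, h, T, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
out = torch.matmul(torch.softmax(scores, dim=-1), v)   # (B, h, T, d_k)
print(out.shape)                                       # torch.Size([2, 8, 10, 32])
```

Note that `new_cache = torch.cat((k, v), dim=-1)` in the diff is built before this expansion, so the decoding cache stores only `h_kv` key/value heads.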
14 changes: 11 additions & 3 deletions wenet/transformer/decoder.py
@@ -84,6 +84,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
super().__init__()
attention_dim = encoder_output_size
@@ -114,11 +116,11 @@ def __init__(
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, query_bias, key_bias,
value_bias, use_sdpa),
value_bias, use_sdpa, n_kv_head, head_dim),
WENET_ATTENTION_CLASSES["crossattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa)
if src_attention else None,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim) if src_attention else None,
mlp_class(attention_dim, linear_units, dropout_rate,
activation, mlp_bias),
dropout_rate,
@@ -334,6 +336,8 @@ def __init__(
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):

super().__init__()
@@ -360,6 +364,8 @@ def __init__(
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

self.right_decoder = TransformerDecoder(
@@ -384,6 +390,8 @@ def __init__(
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

def forward(
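The decoder changes above only thread the two new arguments through to the attention constructors; the practical payoff is a smaller key/value cache at decode time, since (as noted for attention.py) the cache keeps `h_kv` rather than `h` heads. A back-of-the-envelope check with made-up sizes:

```python
# Per-layer cache is cat((k, v), dim=-1) with k, v of shape
# (batch, h_kv, time, d_k), so its size scales with h_kv.
batch, time, d_k = 1, 100, 64
h, h_kv = 16, 2                              # 16 query heads, 2 kv heads (GQA)

cache_mha = 2 * batch * h * time * d_k       # baseline: h_kv == h
cache_gqa = 2 * batch * h_kv * time * d_k    # grouped-query: h_kv < h
print(cache_mha // cache_gqa)                # 8 -> 8x smaller kv cache
```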
11 changes: 9 additions & 2 deletions wenet/transformer/encoder.py
@@ -14,7 +14,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint as ckpt
@@ -375,6 +375,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
""" Construct TransformerEncoder

@@ -396,7 +398,8 @@ def __init__(
output_size,
attention_dropout_rate,
query_bias, key_bias,
value_bias, use_sdpa),
value_bias, use_sdpa,
n_kv_head, head_dim),
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
@@ -445,6 +448,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
"""Construct ConformerEncoder

@@ -481,6 +486,8 @@ def __init__(
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)
# feed-forward module definition
positionwise_layer_args = (
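Finally, a hedged usage sketch matching the three cases in the new docstring; it assumes this branch of wenet is importable and only inspects the key-projection shapes (the concrete sizes are made up, not from the PR):

```python
from wenet.transformer.attention import MultiHeadedAttention

n_feat, n_head, head_dim = 256, 4, 64

# case 1: n_kv_head == None, head_dim == None -> standard multi-head attention
mha = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.0)
# case 2: a single shared kv head -> multi-query attention
mqa = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.0,
                           n_kv_head=1, head_dim=head_dim)
# case 3: 2 kv heads shared by 4 query heads -> grouped-query attention
gqa = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.0,
                           n_kv_head=2, head_dim=head_dim)

for name, m in [("MHA", mha), ("MQA", mqa), ("GQA", gqa)]:
    # linear_k maps n_feat -> inner_kv_dim (= head_dim * n_kv_head when set)
    print(name, tuple(m.linear_k.weight.shape))
# Expected: MHA (256, 256), MQA (64, 256), GQA (128, 256)
```

The encoder changes are symmetric: `n_kv_head` and `head_dim` default to `None`, so existing configs keep plain multi-head attention.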