Skip to content

Commit

Permalink
[transformer] add rope for transformer/conformer (#2458)
Browse files Browse the repository at this point in the history
* [transformer] add rope for transformers/conformer

* it works

* rm comment

* fix typo

* fix init

* add assert attn type for transformer

* fix pos_emb in transformer

* llama rope and google rope work!
  • Loading branch information
Mddct authored Apr 4, 2024
1 parent de21463 commit 4d12918
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 14 deletions.
100 changes: 100 additions & 0 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch
from torch import nn

from wenet.utils.rope_utils import llama_apply_rotary_emb


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Expand Down Expand Up @@ -600,3 +602,101 @@ def forward(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache


class RopeMultiHeadedAttention(MultiHeadedAttention):

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa, n_kv_head, head_dim)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute rope scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
Wenet.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# NOTE(Mddct): In order to make the code easier to read,
# these two lines are not placed in MultiHeadedAttention.
q = llama_apply_rotary_emb(q, pos_emb)
k = llama_apply_rotary_emb(k, pos_emb)
# see above
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)

if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=1,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=1,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
else:
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask.unsqueeze(1),
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache
40 changes: 40 additions & 0 deletions wenet/transformer/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch.nn.functional as F
import numpy as np

from wenet.utils.rope_utils import precompute_freqs_cis


class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
Expand Down Expand Up @@ -194,3 +196,41 @@ def forward(self,
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return torch.zeros(1, size, self.d_model)


class RopePositionalEncoding(PositionalEncoding):

def __init__(self,
d_model: int,
head_dim: int,
dropout_rate: float,
max_len: int = 1500,
rope_theta=10000.0):
super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
delattr(self, 'pe')

pe = precompute_freqs_cis(head_dim, max_len * 2, rope_theta)
self.register_buffer("pe", pe.unsqueeze(0))
self.dropout_rate = dropout_rate

def forward(
self,
x: torch.Tensor,
offset: Union[int,
torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:

pos_emb = self.position_encoding(offset, x.size(1), False)
pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2]
# NOTE(Mddct): some model don't scale
# TODO(Mddct): fix
x = x * self.xscale
# NOTE(Mddct) dropout don't suuport complex float for pos_emb
return self.dropout(x), self.dropout_complex(pos_emb)

def dropout_complex(self, x):
mask = torch.nn.functional.dropout(
torch.ones_like(x.real),
training=self.training,
p=self.dropout_rate,
)
return x * mask
28 changes: 16 additions & 12 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,16 @@ def __init__(
self._output_size = output_size

self.global_cmvn = global_cmvn
pos_emb_class = WENET_EMB_CLASSES[pos_enc_layer_type]
# NOTE(Mddct): head_dim == output_size // attention_heads for most of
# speech tasks, but for other task (LLM),
# head_dim == hidden_size * attention_heads. refactor later
self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
positional_dropout_rate),
)
input_size, output_size, dropout_rate,
pos_emb_class(output_size, positional_dropout_rate)
if pos_enc_layer_type != 'rope_pos' else pos_emb_class(
output_size, output_size //
attention_heads, positional_dropout_rate))

assert layer_norm_type in ['layer_norm', 'rms_norm']
self.normalize_before = normalize_before
Expand Down Expand Up @@ -377,6 +380,7 @@ def __init__(
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
selfattention_layer_type: str = "selfattn",
):
""" Construct TransformerEncoder
Expand All @@ -389,17 +393,17 @@ def __init__(
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa, layer_norm_type, norm_eps)

assert selfattention_layer_type in ['selfattn', 'rope_abs_selfattn']
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
output_size,
attention_dropout_rate,
query_bias, key_bias,
value_bias, use_sdpa,
n_kv_head, head_dim),
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads, output_size, attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim),
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
Expand Down
7 changes: 6 additions & 1 deletion wenet/transformer/encoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def forward(
residual = x
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
x_att, new_att_cache = self.self_attn(x,
x,
x,
mask,
pos_emb,
cache=att_cache)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
Expand Down
6 changes: 5 additions & 1 deletion wenet/utils/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
from wenet.transformer.embedding import (PositionalEncoding,
RelPositionalEncoding,
RopePositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from wenet.transformer.attention import (MultiHeadedAttention,
MultiHeadedCrossAttention,
RelPositionMultiHeadedAttention,
RopeMultiHeadedAttention,
ShawRelPositionMultiHeadedAttention)
from wenet.efficient_conformer.attention import (
GroupedRelPositionMultiHeadedAttention)
Expand Down Expand Up @@ -68,14 +70,16 @@
"abs_pos_whisper": WhisperPositionalEncoding,
"embed_learnable_pe": LearnablePositionalEncoding,
"abs_pos_paraformer": ParaformerPositinoalEncoding,
'rope_pos': RopePositionalEncoding,
}

WENET_ATTENTION_CLASSES = {
"selfattn": MultiHeadedAttention,
"rel_selfattn": RelPositionMultiHeadedAttention,
"grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
"crossattn": MultiHeadedCrossAttention,
'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention
'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention,
'rope_abs_selfattn': RopeMultiHeadedAttention,
}

WENET_MLP_CLASSES = {
Expand Down
45 changes: 45 additions & 0 deletions wenet/utils/rope_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch


# copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L84
def precompute_freqs_cis(dim: int,
end: int,
theta: float = 10000.0) -> torch.Tensor:
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis


# modified from:
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
def google_apply_rotary_emb(x: torch.Tensor,
freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)
return x_out


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape[2:] == (x.shape[1], x.shape[-1])
# 2 is seq_len in wenet
shape = [
d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
]
return freqs_cis.view(*shape)


def llama_apply_rotary_emb(x: torch.Tensor,
freqs_cis: torch.Tensor) -> torch.Tensor:
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, x_)
x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
return x_out.type_as(x)

0 comments on commit 4d12918

Please sign in to comment.