add llama3 rotary embedding
tangzhiyi11 committed Oct 31, 2024
1 parent 6d843f4 commit 0417c56
Showing 1 changed file with 57 additions and 21 deletions.
78 changes: 57 additions & 21 deletions lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
import torch
from torch import nn

-from ..default.rotary_embedding import (Llama3RotaryEmbeddingImpl,
-                                        LlamaDynamicNTKScalingRotaryEmbedding)
+from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
RopeType, RotaryEmbeddingBuilder,
RotaryEmbeddingImpl, YarnParameters)
@@ -13,8 +14,7 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor,
inv_freq: torch.Tensor,
scaling_factor: float,
mscale: float = None,
-                          dtype: torch.dtype = None,
-                          device_type: torch.device = None):
+                          dtype: torch.dtype = None):
"""rotary embedding forward."""
if dtype is None:
dtype = torch.float16
@@ -61,15 +61,13 @@ def __init__(self,
def forward(self, x, position_ids):
"""forward."""
# x: [bs, num_attention_heads, seq_len, head_size]
-        device_type = x.device.type
dtype = x.dtype
if self.inv_freq.device != x.device:
self.inv_freq = self.inv_freq.to(x.device)
return _rotary_embedding_fwd(position_ids,
self.inv_freq,
scaling_factor=self.scaling_factor,
-                                     dtype=dtype,
-                                     device_type=device_type)
+                                     dtype=dtype)


class DlinferLlamaDynamicNTKScalingRotaryEmbedding(
@@ -85,21 +83,22 @@ def __init__(self,
scaling_factor: float = 1.0,
max_position_embeddings: int = 2048):
super().__init__(dim, base, scaling_factor, max_position_embeddings)
-        self.exponent_1 = self.dim / (self.dim - 2)
-        self.exponent_2 = torch.arange(
+        self.dim_scale_ratio = self.dim / (self.dim - 2)
+        self.pos_freq_scaling = torch.arange(
            0, self.dim, 2, dtype=torch.int64).float().cuda() / self.dim
-        self.sub = self.scaling_factor - 1
-        self.div = self.scaling_factor / self.max_position_embeddings
+        self.scale_offset = self.scaling_factor - 1
+        self.pos_scale_factor = self.scaling_factor / \
+            self.max_position_embeddings

def _ntk_inv_freq(self, seq_len: torch.Tensor):
"""ntk_inv_freq."""
base = self.base * ((self.div * seq_len) - self.sub)**self.exponent_1
inv_freq = 1.0 / (base**self.exponent_2)
"""Calculate inverse frequency with NTK scaling."""
base = self.base * ((self.pos_scale_factor * seq_len) -
self.scale_offset)**self.dim_scale_ratio
inv_freq = 1.0 / (base**self.pos_freq_scaling)
return inv_freq

def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
"""forward."""
-        device_type = x.device.type
dtype = x.dtype
seq_len = torch.max(position_ids) + 1
ntk_inv_freq = self._ntk_inv_freq(seq_len)
@@ -111,11 +110,49 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
cos, sin = _rotary_embedding_fwd(position_ids,
inv_freq,
scaling_factor=1.0,
-                                         dtype=dtype,
-                                         device_type=device_type)
+                                         dtype=dtype)
return cos, sin


+class DlinferLlama3RotaryEmbeddingImpl(DlinferRotaryEmbeddingImpl):
+    """llama3 rotary embedding implementation."""
+
+    def __init__(
+        self,
+        dim: int,
+        base: int = 10000,
+        scaling_factor: float = 1.0,
+        low_freq_factor: float = 1.0,
+        high_freq_factor: float = 4.0,
+        original_max_position_embeddings: int = 8194,
+    ):
+        super().__init__(dim, base, scaling_factor)
+        old_context_len = original_max_position_embeddings
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        inv_freq = self.inv_freq
+        factor = self.scaling_factor
+
+        wavelen = 2 * math.pi / inv_freq
+        # wavelen < high_freq_wavelen: do nothing
+        # wavelen > low_freq_wavelen: divide by factor
+        inv_freq_llama = torch.where(wavelen > low_freq_wavelen,
+                                     inv_freq / factor, inv_freq)
+        # otherwise: interpolate between the two, using a smooth factor
+        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
+            high_freq_factor - low_freq_factor)
+        smoothed_inv_freq = (
+            1 - smooth_factor
+        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen >
+                                                            low_freq_wavelen)
+        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq,
+                                     inv_freq_llama)
+        self.scaling_factor = 1.0
+        self.register_buffer('inv_freq', inv_freq_llama)


class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
"""rotary embedding builder."""

@@ -137,10 +174,9 @@ def build(
return DlinferLlamaDynamicNTKScalingRotaryEmbedding(
dim, base, scaling_factor, max_position_embeddings)
elif emb_type == RopeType.Llama3:
-            return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
-                                             llama3_params.low_freq_factor,
-                                             llama3_params.high_freq_factor,
-                                             max_position_embeddings)
+            return DlinferLlama3RotaryEmbeddingImpl(
+                dim, base, scaling_factor, llama3_params.low_freq_factor,
+                llama3_params.high_freq_factor, max_position_embeddings)
else:
raise NotImplementedError(
f'Unsupported embedding type: {emb_type}')
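
Note on the new class: DlinferLlama3RotaryEmbeddingImpl only rewrites inv_freq at construction time, so the rescaling can be checked in isolation. The sketch below mirrors that logic outside the class; the concrete numbers (dim=128, factor=8.0, the low/high frequency factors and an 8192-token original context length) are illustrative assumptions, not values taken from lmdeploy defaults.

import math

import torch

dim, base, factor = 128, 10000, 8.0            # assumed example values
low_freq_factor, high_freq_factor = 1.0, 4.0   # assumed llama3-style factors
old_context_len = 8192                         # assumed original context length

# standard RoPE inverse frequencies, one per pair of head dimensions
inv_freq = 1.0 / base**(torch.arange(0, dim, 2).float() / dim)
wavelen = 2 * math.pi / inv_freq
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

# long wavelengths (low frequencies) are divided by the scaling factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# medium wavelengths are smoothly interpolated between scaled and unscaled
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed = (1 - smooth) * inv_freq_llama / factor + smooth * inv_freq_llama
is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium, smoothed, inv_freq_llama)

# short wavelengths (high frequencies) are left untouched
assert torch.equal(inv_freq_llama[wavelen < high_freq_wavelen],
                   inv_freq[wavelen < high_freq_wavelen])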
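
All of the implementations above funnel into _rotary_embedding_fwd, which returns per-position cos and sin tables built from inv_freq. The snippet below is a generic RoPE formulation sketched purely for illustration under assumed shapes; it is not the dlinfer kernel and it ignores the mscale argument.

import torch


def rope_cos_sin(position_ids: torch.Tensor,
                 inv_freq: torch.Tensor,
                 scaling_factor: float = 1.0,
                 dtype: torch.dtype = torch.float16):
    """Illustrative cos/sin tables for rotary embedding (not the dlinfer op)."""
    # position_ids: [bs, seq_len], inv_freq: [head_dim // 2]
    pos = position_ids.float() / scaling_factor
    freqs = pos.unsqueeze(-1) * inv_freq      # [bs, seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)   # [bs, seq_len, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)


inv_freq = 1.0 / 10000**(torch.arange(0, 64, 2).float() / 64)
cos, sin = rope_cos_sin(torch.arange(16)[None], inv_freq)
print(cos.shape, sin.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 16, 64])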
