Description
I've been working with the pretrained Llama 3 weights and found that the RoPE implementation here does not match the one used elsewhere. The difference is whether you treat sequential entries of the embeddings as (real, imaginary) pairs, or treat the first half as real and the second half as imaginary.
The current torchtitan implementation uses the former, while both Transformers and llama.cpp, for example, use the latter.
This also seems to mean that loading the weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B does not work. I've verified numerically that you need the latter RoPE convention to get correct results with the existing weights. I'm slightly worried that I'm doing something wrong, but perhaps someone else can verify? I can post more code if that helps.
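In the meantime, here's a minimal, self-contained sketch of the layout mismatch (the helper names `rope_interleaved` and `rope_half_split` are mine for illustration, and `precompute_freqs_cis` is re-implemented inline rather than imported from the repo):

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Standard Llama frequency table: e^{i * t * theta^(-2k/dim)} for each position t.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()
    return torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))

def rope_interleaved(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # torchtitan convention: adjacent entries (x[2k], x[2k+1]) form (real, imag) pairs.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_ * freqs_cis).flatten(-2).type_as(x)

def rope_half_split(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Transformers / llama.cpp convention: first half is real, second half is imaginary.
    x1, x2 = x.float().chunk(2, dim=-1)
    cos, sin = freqs_cis.real, freqs_cis.imag
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).type_as(x)

dim, seqlen = 8, 4
x = torch.randn(seqlen, dim)
freqs_cis = precompute_freqs_cis(dim, seqlen)

print(torch.allclose(rope_interleaved(x, freqs_cis), rope_half_split(x, freqs_cis)))  # False

# The two agree only up to a channel permutation: interleave the halves of the
# input, apply the interleaved rotation, then de-interleave the output.
idx = torch.arange(dim).reshape(2, -1).t().flatten()  # [0, d/2, 1, d/2+1, ...]
out = rope_interleaved(x[..., idx], freqs_cis)
out = torch.cat([out[..., ::2], out[..., 1::2]], dim=-1)
print(torch.allclose(out, rope_half_split(x, freqs_cis)))  # True
```

So the math is the same rotation in both cases; what differs is which channels of the checkpoint's projection weights are paired up, which is why the conventions can't be mixed when loading existing weights.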
Here's a small change to apply_rotary_emb that makes it numerically match the cos/sin implementation:
```python
from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk'
    tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    interpreted as complex numbers (first half real, second half imaginary), and
    the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key
        tensor with rotary embeddings.
    """
    # Changed: treat the first half of the head dimension as real and the second
    # half as imaginary, instead of interleaved (real, imag) pairs.
    xq_ = torch.complex(xq[..., : xq.shape[-1] // 2].float(), xq[..., xq.shape[-1] // 2 :].float())
    xk_ = torch.complex(xk[..., : xk.shape[-1] // 2].float(), xk[..., xk.shape[-1] // 2 :].float())
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Added: view_as_real emits interleaved (real, imag) pairs, so de-interleave
    # the output back into the half-split layout (all reals, then all imaginaries).
    xq_out = torch.cat([xq_out[..., ::2], xq_out[..., 1::2]], dim=-1)
    xk_out = torch.cat([xk_out[..., ::2], xk_out[..., 1::2]], dim=-1)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```
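As a quick sanity check of the patch (a sketch, assuming the patched apply_rotary_emb above plus the repo's precompute_freqs_cis and reshape_for_broadcast helpers are in scope; `rotate_half` follows the Transformers formulation):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Transformers-style helper on the half-split layout: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

bsz, seqlen, n_heads, head_dim = 2, 16, 4, 64
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seqlen)  # (seqlen, head_dim // 2), complex

xq_out, _ = apply_rotary_emb(xq, xk, freqs_cis)

# cos/sin duplicated across both halves, as in the Transformers code path.
cos = freqs_cis.real.repeat(1, 2)[None, :, None, :]
sin = freqs_cis.imag.repeat(1, 2)[None, :, None, :]
ref = xq * cos + rotate_half(xq) * sin
print(torch.allclose(xq_out, ref, atol=1e-5))  # expect True
```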