
[DeepSeek-V3] Different rotary embedding implementation between DeepSeek-AI and Transformers #39687


Description

@wwwjn

Counterpart issue : deepseek-ai/DeepSeek-V3#938

Describe the issue
Hi team, I'm working on reproducing the great DeepSeek-V3 model in torchtitan. While running numerical verification, I noticed that the rotary embedding implementations in HF and in this repo differ.

HF: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L339

  • In the HF rotary embedding implementation, q_pe / k_pe are explicitly permuted (the even and odd columns are regrouped) before the rotation is applied.

DeepSeek-AI: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

  • In DeepSeek-AI's implementation, q_pe / k_pe are not permuted.
  • In convert.py, the script loads the HF weights in their original order and does not permute them to account for the ordering difference in apply_rotary_embedding().

This discrepancy produces mathematically different results after the attention module of the first dense layer.
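
To make the layout difference concrete, here is a minimal, self-contained sketch of the two conventions. It is not the exact code from either repo: the shapes and frequency construction are simplified, and the function names are mine. For the same input, the two outputs differ exactly by the even/odd interleave permutation, so any downstream weights that are not permuted to match will see different numbers:

```python
import torch

def rotate_half(x):
    # Half-split rotation used by HF: pairs (x_i, x_{i + d/2}).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_hf(x, cos, sin):
    # HF's modeling_deepseek.py first regroups interleaved columns,
    # [x0, x1, x2, ...] -> [x0, x2, ..., x1, x3, ...], then rotates.
    b, h, s, d = x.shape
    x = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    return x * cos + rotate_half(x) * sin

def apply_rope_deepseek(x, freqs_cis):
    # DeepSeek-AI style: complex multiply over adjacent (even, odd)
    # pairs, i.e. the interleaved convention; x is never permuted.
    xc = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * freqs_cis).flatten(-2).type_as(x)

b, h, s, d = 1, 1, 4, 8
torch.manual_seed(0)
x = torch.randn(b, h, s, d)
theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))  # (d/2,)
angles = torch.outer(torch.arange(s).float(), theta)          # (s, d/2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)         # (s, d)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
freqs_cis = torch.polar(torch.ones_like(angles), angles)      # (s, d/2) complex

out_hf = apply_rope_hf(x, cos, sin)
out_ds = apply_rope_deepseek(x, freqs_cis.view(1, 1, s, d // 2))

perm = torch.cat([torch.arange(0, d, 2), torch.arange(1, d, 2)])
print(torch.allclose(out_hf, out_ds))             # False: different layouts
print(torch.allclose(out_hf, out_ds[..., perm]))  # True: same values, permuted
```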

I want to double-check with the team if I missed something here. Thank you for your help in advance! cc @tianyu-l

To Reproduce
Environment: transformers==4.54.0

I ran the following two runs using the code at https://github.com/wwwjn/DeepSeek-V3, feeding the same randomized inputs to both.

  1. A single forward pass using HuggingFace transformers, with weights from https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/tree/main:
    python hf_implementation/hf_implementation.py --num_layers 5 > hf_outputs.txt 2>&1

  2. A single forward pass using this repo's model implementation (to simplify, I didn't use the …).
    • First, I used convert.py to convert the HF checkpoint weights:
    python convert.py --hf-ckpt-path /data/users/jianiw/dsv3-weights/ --save-path /data/users/jianiw/dsv3-weights-5-layer/ --n-experts 256 --model-parallel 8

    • Second, I ran a single forward pass:
      torchrun --nnodes 1 --nproc-per-node 8 inference/run_single_forward.py --config inference/configs/config_671B.json > dsv3-output.txt 2>&1
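
One way to quantify the gap is a small helper like the following. This is a hypothetical sketch: it assumes both runs are modified to dump the first dense layer's attention output with torch.save, and the file names are mine, not part of either repo:

```python
import torch

# Hypothetical dump files, e.g. torch.save(attn_out, "hf_attn_layer0.pt")
# added to each run; neither repo writes these by default.
hf_out = torch.load("hf_attn_layer0.pt").float()
ds_out = torch.load("dsv3_attn_layer0.pt").float()

diff = (hf_out - ds_out).abs()
print("max abs diff :", diff.max().item())
print("mean abs diff:", diff.mean().item())
# With matched RoPE conventions this should be near fp8/bf16 noise;
# with the layout mismatch above it is not.
print("allclose:", torch.allclose(hf_out, ds_out, atol=1e-2, rtol=1e-2))
```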

Here's the detailed numerical comparison I've seen:

[Screenshots: numerical comparison of the two runs]

Expected behavior
Expected behavior: after the first dense layer's attention module, the outputs should be nearly identical (small differences may remain because of fp8 vs. other precisions).

Additional context
We observed the same issue for the llama3 model before (#30872) and reached a better understanding with @ArthurZucker's help: the llama3 weights on HuggingFace are permuted relative to Meta's original weights, so we had to manually permute the weights back to accommodate the rotary embedding implementation difference. References: pytorch/torchtitan#335, pytorch/torchtitan#1291 (comment).
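
For reference, here is a minimal sketch of the llama3-style fix, assuming the same convention carries over (the function name is mine; the transform is the inverse of the permute() in transformers' convert_llama_weights_to_hf.py). For DeepSeek-V3, only the rotary slice of the q/k projections (q_pe / k_pe) would be a candidate for this treatment, and whether it is needed at all is exactly what this issue asks:

```python
import torch

def unpermute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Inverse of HF's llama permute(): maps the half-split ordering
    # back to interleaved (even, odd) pairs, per head.
    dim1, dim2 = w.shape
    return (
        w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
         .transpose(1, 2)
         .reshape(dim1, dim2)
    )

# Hypothetical usage inside a converter:
# state_dict["q_proj.weight"] = unpermute(state_dict["q_proj.weight"], n_heads)
```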
