llama3 rope #55


Merged: 22 commits from raymond/llama3_rope into main on Dec 9, 2024

Conversation

@RaymondLi0 (Contributor) commented Nov 20, 2024

✨ Description

Closes #39

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

  • add llama3-style rope scaling (sketched below)
  • add llama3 test config (run with MODEL=llama3 pytest ./tests/test_checkpoint.py)
  • handle nested HF configs in conversion
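
A rough sketch of the llama3-style scaling added here, using the four HuggingFace `rope_scaling` parameters (`factor`, `low_freq_factor`, `high_freq_factor`, `original_max_position_embeddings`) with Llama 3.1's published defaults; the function name and where it would live in Fast-LLM are illustrative, not the actual code in this PR:

import math
import torch

def llama3_scale_frequencies(
    inv_freq: torch.Tensor,  # base inverse frequencies, theta ** (-2 * i / head_dim)
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> torch.Tensor:
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Long wavelengths (low frequencies) are compressed by `factor`; short ones are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # The band in between is blended smoothly between the scaled and unscaled frequencies.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)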

Testing

MODEL=llama3 pytest ./tests/test_checkpoint.py passes.
When "sabotaging" the conversion here

def export_rotary_scaling_type(fast_llm_value: RotaryScalingType) -> dict[str, typing.Any] | None:
    match fast_llm_value:
        case RotaryScalingType.none:
            return "default"
        case RotaryScalingType.llama3:
            return "llama3"
        case _:
            raise ValueError(f"Unsupported rotary scaling type: {fast_llm_value}")


def import_rotary_scaling_type(export_value):
    if export_value is None:
        return RotaryScalingType.none
    match export_value:
        case "default":
            return RotaryScalingType.none
        case "llama3":
            return RotaryScalingType.llama3
        case _:
            raise ValueError(f"Unsupported rotary scaling type: {export_value}")

with

def export_rotary_scaling_type(fast_llm_value: RotaryScalingType) -> dict[str, typing.Any] | None:
    return None


def import_rotary_scaling_type(export_value):
    return RotaryScalingType.llama3

then test_run_converted_model fails, showing that model outputs differ when using an incorrect rotary config.

@tscholak (Collaborator):

Hi @RaymondLi0! Functionally this looks like what we want (pending model conversion), but are you confident (i.e. have you checked) that the forward and backward passes of hf-llama and fast-llm-llama are the same?

@jlamypoirier (Collaborator) left a comment:

I didn't review for correctness, but looks ok. I'll leave the approval to @tscholak

@@ -127,6 +127,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
        desc="Scale for the rotary positional embeddings. Default: -math.log(10000) = -9.210",
        hint=FieldHint.feature,
    )
    rotary_scaling_type: RotaryScalingType = Field(
Collaborator:

how about rotary_embedding_type so we make it more general?
Or better, merge with use_rotary_embeddings into rotary_embeddings: RotaryEmbeddingType | None

Contributor Author:

Would make sense.
Although that would make this a breaking change right? Would that be OK?

Contributor Author:

@tscholak @jlamypoirier I could merge into rotary_embeddings: RotaryEmbeddingType | None, and keep use_rotary_embeddings as a deprecated argument to maintain backwards compatibility. wdyt?

Collaborator:

ok!

Contributor Author:

Hmm, in addition to breaking existing checkpoints it will make conversion more complex than needed. Let's postpone to another PR?

Collaborator:

Sounds good! We only want to unblock model conversion at this point. Thanks!

@RaymondLi0 changed the title from "Raymond/llama3 rope" to "llama3 rope" on Nov 20, 2024
@RaymondLi0 (Contributor Author) commented Nov 21, 2024

> Hi @RaymondLi0! Functionally this looks like what we want (pending model conversion), but are you confident (i.e. have you checked) that the forward and backward passes of hf-llama and fast-llm-llama are the same?

Haven't done that check. Do we have existing tests comparing the forward/backward of fast-llm and hf-transformers? If not, I can look into adding this.

@jlamypoirier (Collaborator) commented Nov 21, 2024

> Haven't done that check. Do we have existing tests comparing the forward/backward of fast-llm and hf-transformers? If not, I can look into adding this.

There is one in test_checkpoint https://github.com/ServiceNow/Fast-LLM/blob/main/tests/test_checkpoint.py#L31. It could work for this one if we added a llama3 model to the testing suite (in common.py). (It would also be a good test of conversion, etc.)
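
For context, a forward-pass comparison of the kind referenced above could look roughly like the sketch below; the calling conventions for the two models are assumptions, not the actual code in tests/test_checkpoint.py:

import torch

def assert_same_logits(fast_llm_model, hf_model, vocab_size: int, seq_len: int = 32):
    """Feed both models the same token ids and check that the logits match."""
    torch.manual_seed(0)
    input_ids = torch.randint(0, vocab_size, (1, seq_len))
    with torch.no_grad():
        fast_llm_logits = fast_llm_model(input_ids)  # assumed to return logits directly
        hf_logits = hf_model(input_ids).logits  # standard HF CausalLM output
    torch.testing.assert_close(fast_llm_logits, hf_logits, rtol=1e-4, atol=1e-4)

A backward check would additionally compare gradients after calling .backward() on matching losses.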

@tscholak (Collaborator) left a comment:

LGTM, thanks @RaymondLi0!

@tscholak (Collaborator) commented Dec 4, 2024

@RaymondLi0 @jlamypoirier merge?

@jlamypoirier (Collaborator):

Looks ok, I'll do some adjustments and then merge.

@jlamypoirier (Collaborator) commented Dec 4, 2024

Main changes I did:

  • Adjusted names (simplify, avoid abbreviations)
  • Moved the nested-dict functions to utils since they are useful elsewhere, and generalized them a bit
  • Moved rotary preprocessing to the layers module so we can use the rotary config instead of a ton of arguments
  • Replaced the rotary scaling type with a rotary type, which also serves as the rotary enabling parameter
  • Replaced the rotary log scale with theta (wanted to do this for a while; now is a good time)
  • Moved the remaining rotary parameters to a new rotary config where they belong
  • Added backward compatibility for the modified config parameters

Still need to run tests, @RaymondLi0 @tscholak feel free to have a look and comment.
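
For illustration only, the backward-compatibility mapping described in the list above could look roughly like this; `theta`, the nested rotary `type`, and the old documented default scale (-math.log(10000)) come from this thread, while every other field name and the "none" spelling for the disabled state are assumptions:

import math

def upgrade_rotary_config(old: dict) -> dict:
    """Map the old flat rotary fields onto a nested rotary sub-config (sketch)."""
    use_rotary = old.get("use_rotary_embeddings", False)
    # The old field stored a log scale, documented as -math.log(10000) = -9.210,
    # so theta is recovered as exp(-scale).
    scale = old.get("rotary_embedding_scale", -math.log(10000))
    return {
        "rotary": {
            # In the new design the rotary type doubles as the enabling switch.
            "type": old.get("rotary_scaling_type", "default") if use_rotary else "none",
            "theta": math.exp(-scale),
        }
    }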

@jlamypoirier (Collaborator) left a comment:

Got the tests to pass for both default and llama3 models. Not sure if that's enough? (Could try loading an old checkpoint to be sure)

@tscholak (Collaborator) commented Dec 5, 2024

I'll continue the training of Llama 3.1 8B with this PR as a further test.

@tscholak (Collaborator) left a comment:

This needs revision, because it doesn't work correctly.

I am currently doing continued pretraining (CPT) of Llama 3.1 8B on LambdaLabs, and I see this in config.yaml in the output folder:

model:
  base_model:
    cross_entropy_impl: fused
    init_method_std_embed: 0.015625
    tie_word_embeddings: false
    transformer:
      activation_type: silu
      add_linear_biases: false
      ffn_hidden_size: 14336
      gated: true
      head_groups: 8
      hidden_size: 4096
      init_method_std: 0.015625
      init_method_std_attn_proj: 0.001953125
      init_method_std_mlp_1: 0.015625
      init_method_std_mlp_2: 0.001953125
      init_method_std_qkv: 0.015625
      kv_channels: 128
      mlp_lr_scale:
      - null
      normalization:
        type: rms_norm
      num_attention_heads: 32
      num_layers: 32
      rotary:
        theta: 500000.0
        type: default
    use_position_embeddings: false
    vocab_size: 128256

Instead of type: default in the model.base_model.transformer.rotary, I expected to see llama3.

Please fix @RaymondLi0 @jlamypoirier

@RaymondLi0 (Contributor Author):

Hmm this used to work, let me have a look

@RaymondLi0 (Contributor Author):

@jlamypoirier The logic here

if converter.export_name is None or converter.export_name not in config
does not work because converter.export_name can be a tuple[str, ...].
That's why I had moved it to get_nested_dict_value in fc9ef13.
The result is that currently, value is set to None when importing a nested dict value like rope_scaling.rope_type.

Was there a specific reason to remove that logic from get_nested_dict_value?

@RaymondLi0 (Contributor Author):

I fixed the config comparison tool that should have caught this error in the tests. But another error appeared in the tests, having a look

@RaymondLi0 (Contributor Author):

Two points remain:

  • an error at
    _compare_configs(config, model._base_model_config)
    because model._base_model_config contains (validated?) default values while config doesn't.
  • address the following:

> @jlamypoirier The logic here
>
> if converter.export_name is None or converter.export_name not in config
>
> does not work because converter.export_name can be a tuple[str, ...].
> That's why I had moved it to get_nested_dict_value in fc9ef13.
> The result is that currently, value is set to None when importing a nested dict value like rope_scaling.rope_type.
> Was there a specific reason to remove that logic from get_nested_dict_value?

@tscholak (Collaborator) commented Dec 6, 2024

Thanks @RaymondLi0
Additionally, exporting the checkpoint failed:

2024-12-06 17:19:24,749 [Rank 00] Saving export at iteration 1000
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'original_max_position_embeddings', 'factor', 'high_freq_factor', 'low_freq_factor'}
2024-12-06 17:19:24,773 [Rank 00] Saving tensors to /app/fast-llm-tutorial/experiment/export/llama/1000/model_0.safetensors
2024-12-06 17:20:00,939 [Rank 00] Saving tensors to /app/fast-llm-tutorial/experiment/export/llama/1000/model_1.safetensors
[E1206 17:20:24.336092093 ProcessGroupNCCL.cpp:588] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=308230, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=60000) ran for 60004 milliseconds before timing out.
[E1206 17:20:24.336503904 ProcessGroupNCCL.cpp:1683] [PG  Rank 7] Exception (either an error or timeout) detected by watchdog at work: 308230, last enqueued NCCL work: 308230, last completed NCCL work: 308229.
[E1206 17:20:24.338932817 ProcessGroupNCCL.cpp:588] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=308230, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=60000) ran for 60008 milliseconds before timing out.
[E1206 17:20:24.339213368 ProcessGroupNCCL.cpp:1683] [PG  Rank 3] Exception (either an error or timeout) detected by watchdog at work: 308230, last enqueued NCCL work: 308230, last completed NCCL work: 308229.

@jlamypoirier (Collaborator):

> @jlamypoirier The logic here
>
> if converter.export_name is None or converter.export_name not in config
>
> does not work because converter.export_name can be a tuple[str, ...].
> That's why I had moved it to get_nested_dict_value in fc9ef13.
> The result is that currently, value is set to None when importing a nested dict value like rope_scaling.rope_type.
> Was there a specific reason to remove that logic from get_nested_dict_value?

I removed the None part so it could be more generic and be used elsewhere. The mistake here is that I didn't remove `or converter.export_name not in config`; it's supposed to be handled through KeyError.
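
For reference, a nested lookup of this shape (signature assumed, not necessarily the one in utils) naturally surfaces a missing path such as ("rope_scaling", "rope_type") as a KeyError, which the import code can then catch:

def get_nested_dict_value(config: dict, key):
    """Look up `key` in `config`; `key` may be a single key or a tuple path."""
    keys = key if isinstance(key, tuple) else (key,)
    value = config
    for part in keys:
        value = value[part]  # raises KeyError when any part of the path is missing
    return value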

@@ -241,7 +241,7 @@ def _compare_configs(config_ref, config_test):
 @pytest.mark.depends(on=["test_converted_distributed"])
 def test_load_pretrained_distributed_checkpoint():
     config = TEST_ARCHITECTURE_CONFIG_CLS.from_dict(
-        yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r")), strict=False
+        yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r"))["model"]["base_model"], strict=False
Collaborator:

I don't understand this one. Wasn't it working before? I'd have caught it in tests...

Contributor Author:

The config comparison was broken I think (fixed in fa6440e), which is why this one wasn't caught

@jlamypoirier (Collaborator):

Bugs should be fixed, but I still haven't checked beyond the tests. One of the tests fails (test_load_distributed_checkpoint_dp2), but it's very minor and unrelated to this PR; it looks like something goes wrong with the default init_method_std when loading models.

@tscholak (Collaborator) commented Dec 7, 2024

Thanks @RaymondLi0 and @jlamypoirier for making progress on this. Loading the Llama 3.1 model with the correct rope scaling type works fine now:

model:
  base_model:
    transformer:
      rotary:
        type: llama3
        theta: 500000.0

However, exporting the model to Llama 3.1 format still fails:

2024-12-07 04:15:10,553 [Rank 00] Saving checkpoint at iteration 1000
2024-12-07 04:15:31,546 [Rank 00] Saved checkpoint to /app/fast-llm-tutorial/experiment/checkpoint/1000
2024-12-07 04:15:31,548 [Rank 00] Saving export at iteration 1000
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got 8192 and max_position_embeddings=2048
2024-12-07 04:15:31,601 [Rank 00] Saving tensors to /app/fast-llm-tutorial/experiment/export/llama/1000/model_0.safetensors
2024-12-07 04:16:07,999 [Rank 00] Saving tensors to /app/fast-llm-tutorial/experiment/export/llama/1000/model_1.safetensors
[E1207 04:16:31.569811685 ProcessGroupNCCL.cpp:588] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=308230, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=60000) ran for 60003 milliseconds before timing out.

@RaymondLi0 (Contributor Author):

@tscholak I am not able to reproduce this issue, could you send more details? Which config did you run?

@tscholak (Collaborator) commented Dec 9, 2024

Hi @RaymondLi0, I ran the Big QuickStart guide on LambdaLabs, so:

kubectl apply -f - <<EOF
apiVersion: "kubeflow.org/v1"
kind: "PyTorchJob"
metadata:
  name: "fast-llm-train"
spec:
  nprocPerNode: "8"
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:sha-41fa4bd
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu:
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-inputs
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-inputs
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
    Worker:
      replicas: 3
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:sha-41fa4bd
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu:
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-inputs
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-inputs
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
EOF

with

training:
  train_iters: 100_000
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 1000
  checkpoint:
    interval: 1000
    keep: 5
  test_iters: 0
  export:
    format: llama
    interval: 1_000
  wandb:
    project_name: fast-llm-tutorial
    group_name: Big
    entity_name: null
batch:
  micro_batch_size: 2
  sequence_length: 4096
  batch_size: 512
data:
  format: file
  path: fast-llm-tutorial/dataset/fast_llm_dataset.json
  split: [99, 1, 0]
optimizer:
  weight_decay: 0.1
  beta_1: 0.9
  beta_2: 0.95
  learning_rate:
    base: 6.0e-04
    minimum: 6.0e-05
    decay_style: cosine
    decay_iterations: 100_000
    warmup_iterations: 2000
pretrained:
  format: llama
  path: fast-llm-tutorial/pretrained-model
  model_weights: yes
model:
  base_model:
    transformer:
      use_flash_attention: yes
    cross_entropy_impl: fused
  multi_stage:
    zero_stage: 2
  distributed:
    training_dtype: bf16
run:
  experiment_dir: fast-llm-tutorial/experiment

@RaymondLi0 (Contributor Author):

Could you check if you get the same error on toolkit? I ran a very similar config on yul201 that exported the model successfully:

# @package _global_
defaults:
  - base
  - /job/spec: native_multinode

job:
  nodes: 2

cmd:
  experiment_name: llama_3.1_8b_cpt_nodes_4
  project:
    name: "fast_llm_recipes"
    version: "v1"

fast_llm:
  training:
    train_iters: 100  # TODO
    logs:
      interval: 10
    validation:
      iterations: 25
      interval: 1000
    checkpoint:
      interval: 1000
      keep: 5
    test_iters: 0
    export:  
      format: llama
      interval: 100  # TODO
  batch:
    micro_batch_size: 2
    sequence_length: 4096
    batch_size: 256
  data:
    format: file
    path: /mnt/datasets/owmath_dolmafw70_merged.json
    split: [99, 1, 0]  
  optimizer:  
    weight_decay: 0.1
    beta_1: 0.9
    beta_2: 0.95
    learning_rate:
      base: 6.0e-04
      minimum: 6.0e-05
      decay_style: cosine
      decay_iterations: 100_000
      warmup_iterations: 2000
  pretrained:
    format: llama  
    path: /mnt/checkpoints/Llama-3.1-8B/
    model_weights: yes
  model:
    base_model:
      transformer:
        use_flash_attention: yes
      cross_entropy_impl: fused
    multi_stage:
      zero_stage: 2
    distributed:
      training_dtype: bf16  
  run:
    experiment_dir: /mnt/checkpoints/fast_llm_dev/Llama-3.1-8B-nodes_2 

@tscholak (Collaborator) left a comment:

Confirmed that checkpoint loading and saving works as designed for Llama-3.1-8B.

@tscholak merged commit 53286a3 into main on Dec 9, 2024
4 checks passed
@tscholak deleted the raymond/llama3_rope branch on December 9, 2024 at 20:04

Successfully merging this pull request may close these issues.

[feat] Llama 3.x rope scaling support
3 participants