
feat: dataclass args for accelerated MoE tuning #390

Open
wants to merge 19 commits into base: main

Conversation


@willmj willmj commented Nov 15, 2024

Description of the change

This PR adds one dataclass argument to enable accelerated MoE for sft_trainer.py via the new fms-acceleration accelerated-moe plugin, and allows accelerated MoE full fine-tuning with the --fast_moe flag. --fast_moe enables a technique that trains the experts of Mixture of Experts (MoE) models in parallel instead of sequentially (an example invocation follows the benchmark table below).
With this flag, we expect a major speedup in train time and a decrease in memory usage on Mixture of Experts models.

| Framework Config | EP Degree (parameter) | Model | GPUs | Train Runtime (s) | Speedup | Memory Usage | Memory Savings |
|---|---|---|---|---|---|---|---|
| none | N/A | granite 3b a800 | 1 | 2371.93 | base | 71199 | base |
| Scatter MoE | 1 | granite 3b a800 | 1 | 742.739 | 3.19 | 71187 | 1.0 |
| Scatter MoE + Padding Free | 1 | granite 3b a800 | 1 | 631.976 | 3.75 | 48401 | 0.68 |
| Scatter MoE + Padding Free + foak | 1 | granite 3b a800 | 1 | 615.453 | 3.85 | 42651 | 0.6 |
| none | N/A | mixtral 8x7b | 8 | 4180.95 | base | 65607 | base |
| Scatter MoE | 8 | mixtral 8x7b | 8 | 1071.2 | 3.9 | 52004.8 | 0.79 |
| Scatter MoE + Padding Free + foak | 8 | mixtral 8x7b | 8 | 1043.67 | 4.01 | 51961.2 | 0.79 |
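
For reference, a minimal sketch of enabling the flag from the command line, assuming sft_trainer.py is launched directly with CLI flags matching the JSON keys used later in this PR; the paths are placeholders and only --fast_moe is new:

    # Sketch only: launcher and paths are illustrative, other arguments mirror the configs below.
    python sft_trainer.py \
      --model_name_or_path <moe-model-path> \
      --training_data_path <train-data.json> \
      --output_dir <output-dir> \
      --fast_moe 1

The value passed to --fast_moe corresponds to the "EP Degree (parameter)" column in the table above.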

Related issue number

How to verify the PR

This PR is a work in progress; it requires more testing and the official release of fms-acceleration-moe.

  • To verify, run a tuning job with fast_moe.
  • Run a tuning job with other plugins added on top of fast_moe.
  • Ensure that incorrect parameters result in failures.
  • Ensure that non-MoE models cannot be trained with this plugin set (a rough sketch of these checks follows this list).
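
A rough sketch of those verification runs; the launcher, model paths, and the other plugins' flags are placeholders and assumptions, not part of this PR:

    # 1. Tuning job with fast_moe alone.
    python sft_trainer.py --model_name_or_path <moe-model> \
        --training_data_path <train.json> --output_dir <out-fast> --fast_moe 1

    # 2. fast_moe stacked with another fms-acceleration plugin (e.g. the
    #    padding-free plugin from the benchmark table; its flag is not shown here).
    python sft_trainer.py --model_name_or_path <moe-model> \
        --training_data_path <train.json> --output_dir <out-stacked> \
        --fast_moe 1 <other-plugin-flags>

    # 3. An invalid value should fail fast rather than silently fall back.
    python sft_trainer.py --model_name_or_path <moe-model> \
        --training_data_path <train.json> --output_dir <out-bad> --fast_moe not_an_int

    # 4. A dense (non-MoE) model with --fast_moe set should be rejected.
    python sft_trainer.py --model_name_or_path <dense-model> \
        --training_data_path <train.json> --output_dir <out-dense> --fast_moe 1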

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass


Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@github-actions github-actions bot added the feat label Nov 15, 2024

willmj commented Nov 21, 2024

Tested the new flag on Granite 3.0 3B MoE; inference testing is up next.

Regular MoE tuning

Tested this branch without fast_moe

      {
          "model_name_or_path": "/ibm_dmf_lakehouse/models/base_training/shared/granite-3.0-3b-a800m-base/r240924a",
          "training_data_path": "/testing/tuning/input/cc_tone_sft_format_1000_train.json",
          "output_dir": "/testing/tuning/output/granite-3b-moe/ft/20241120_1014-tone",
          "save_model_dir": "/testing/tuning/output/granite-3b-moe/ft/20241120_1014-tone/save_model",
          "num_train_epochs": 10.0,
          "per_device_train_batch_size": 2,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-5,
          "response_template": "\n### Response:",
          "dataset_text_field": "output"
      }

Training logs:

{'loss': 0.8331, 'grad_norm': 364.0, 'learning_rate': 9e-06, 'epoch': 1.0}
{'loss': 0.4259, 'grad_norm': 0.10986328125, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.0}
{'loss': 0.1667, 'grad_norm': 25.25, 'learning_rate': 7e-06, 'epoch': 3.0}
{'loss': 0.0304, 'grad_norm': 21.625, 'learning_rate': 6e-06, 'epoch': 4.0}
{'loss': 0.0023, 'grad_norm': 0.005828857421875, 'learning_rate': 5e-06, 'epoch': 5.0}
{'loss': 0.0004, 'grad_norm': 0.005157470703125, 'learning_rate': 4.000000000000001e-06, 'epoch': 6.0}
{'loss': 0.0001, 'grad_norm': 0.0038604736328125, 'learning_rate': 3e-06, 'epoch': 7.0}
{'loss': 0.0001, 'grad_norm': 0.000469207763671875, 'learning_rate': 2.0000000000000003e-06, 'epoch': 8.0}
{'loss': 0.0001, 'grad_norm': 0.004547119140625, 'learning_rate': 1.0000000000000002e-06, 'epoch': 9.0}
{'loss': 0.0001, 'grad_norm': 0.01324462890625, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 5311.528, 'train_samples_per_second': 1.883, 'train_steps_per_second': 0.941, 'train_loss': 0.1459229184500873, 'epoch': 10.0}

Location: /testing/tuning/output/granite-3b-moe/ft/20241121_1314-tone/save_model

Fast MoE

And with fast_moe:

      {
          "model_name_or_path": "/ibm_dmf_lakehouse/models/base_training/shared/granite-3.0-3b-a800m-base/r240924a",
          "training_data_path": "/testing/tuning/input/cc_tone_sft_format_1000_train.json",
          "output_dir": "/testing/tuning/output/granite-3b-moe/ft/20241121_1014-tone-FAST",
          "save_model_dir": "/testing/tuning/output/granite-3b-moe/ft/20241121_1014-tone-FAST/save_model",
          "num_train_epochs": 10.0,
          "per_device_train_batch_size": 2,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-5,
          "response_template": "\n### Response:",
          "dataset_text_field": "output",
          "fast_moe": 1
      }

Training logs:

{'loss': 0.4279, 'grad_norm': 0.076171875, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.0}
{'loss': 0.1377, 'grad_norm': 3.78125, 'learning_rate': 7e-06, 'epoch': 3.0}
{'loss': 0.0384, 'grad_norm': 0.81640625, 'learning_rate': 6e-06, 'epoch': 4.0}
{'loss': 0.0031, 'grad_norm': 0.003997802734375, 'learning_rate': 5e-06, 'epoch': 5.0}
{'loss': 0.0006, 'grad_norm': 0.002044677734375, 'learning_rate': 4.000000000000001e-06, 'epoch': 6.0}
{'loss': 0.0002, 'grad_norm': 0.0032196044921875, 'learning_rate': 3e-06, 'epoch': 7.0}
{'loss': 0.0001, 'grad_norm': 0.002288818359375, 'learning_rate': 2.0000000000000003e-06, 'epoch': 8.0}
{'loss': 0.0001, 'grad_norm': 0.0087890625, 'learning_rate': 1.0000000000000002e-06, 'epoch': 9.0}
{'loss': 0.0001, 'grad_norm': 0.0115966796875, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 2140.2943, 'train_samples_per_second': 4.672, 'train_steps_per_second': 2.336, 'train_loss': 0.14420232288464904, 'epoch': 10.0}

Location: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/save_model

Results

We see a 2.48x speedup in train runtime with fast_moe enabled.
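
The figure follows directly from the train_runtime values in the two logs above:

    python -c "print(round(5311.528 / 2140.2943, 2))"   # -> 2.48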


fabianlim commented Nov 22, 2024

@willmj In the original PR we reported benchmarks where the batch sizes were different, but the numbers you report here are in that ballpark.

c.f. the numbers in the benchmark table from the original PR.


willmj commented Dec 9, 2024

After running checkpoint utils on the branch Fabian created for safetensors, vLLM inference ran as expected:

% grpcurl -plaintext -proto ./proto/generation.proto -d "{\"params\":{\"method\":\"GREEDY\", \"stopping\": {\"max_new_tokens\": 128}}, \"requests\": [{\"text\":\"### Text: @sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas\n\n### Label:\"}]}" localhost:8033 fmaas.GenerationService/Generate
{
  "responses": [
    {
      "generatedTokenCount": 128,
      "text": " sad, frustrated, anxious, anxious, frustrated, sad, anxious, anxious, frustrated, sad, frustrated, anxious, frustrated, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad,",
      "inputTokenCount": 38,
      "stopReason": "MAX_TOKENS"
    }
  ]
}

Post-processing completed with this script (thanks again Fabian!):

import os, shutil, json

from fms_acceleration_moe.utils.checkpoint_utils import (
    get_state_dict_from_safe_checkpoint,
    recover_original_state_dict_from_checkpoint,
    save_single_safetensor,
)
from transformers import AutoModelForCausalLM
from transformers.utils import CONFIG_NAME

checkpoint_dir = "<scattermoe-checkpoint-dir>"
output_dir = "<output-dir>"
pretrained_model_name_or_path = "<original-model-dir>"

# Copy the model config next to the converted weights; if no base model path
# was given, recover it from the config's _name_or_path field.
config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
target_config_file = os.path.join(output_dir, CONFIG_NAME)
if os.path.exists(config_file):
    shutil.copyfile(config_file, target_config_file)

    if not pretrained_model_name_or_path:
        with open(target_config_file) as f:
            pretrained_model_name_or_path = json.load(f).get("_name_or_path")

# Load the ScatterMoE-format checkpoint and map it back to the original
# (standard Hugging Face) state-dict layout.
sd = get_state_dict_from_safe_checkpoint(checkpoint_dir)
sd = recover_original_state_dict_from_checkpoint(sd, pretrained_model_name_or_path)

# Write the recovered state dict out as a single safetensors file.
save_single_safetensor(
    {k: v.contiguous() for k, v in sd.items()},
    output_dir,
    metadata={"format": "pt"},
)

# Test that the converted state dict loads with vanilla transformers.
model = AutoModelForCausalLM.from_pretrained(output_dir)

FastMOE model saved in: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/save_model
Reconstructed SD model saved in: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/standard-sd

@willmj willmj marked this pull request as ready for review December 9, 2024 18:49
@willmj willmj requested a review from kmehant as a code owner December 9, 2024 18:49
@fabianlim fabianlim changed the title feat: [WIP] dataclass args for accelerated MoE tuning feat: dataclass args for accelerated MoE tuning Dec 10, 2024