Add ExpertParallel Mixture-of-Experts Plugin #99

Merged
merged 21 commits into from
Nov 13, 2024
Conversation

@fabianlim (Contributor) commented Nov 3, 2024

This PR will supersede #69.

This PR adds a new accelerated-moe plugin that is Triton-only.

TODO:

Performance

For ibm-granite/granite-3.0-3b-a800m-instruct and mistralai/Mixtral-8x7B-Instruct-v0.1:

  • effective batch size 128
  • bfloat16, no mixed precision
  • torch memory logging disabled to get more competitive runtimes
  • framework_config = none: plain FSDP baseline
  • moe-scattermoe-granite-ep1: MoE expert-parallel world_size = 1
  • moe-scattermoe-granite-ep2: MoE expert-parallel world_size = 2
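The ep settings above can be read as follows: under expert parallelism, the experts in each MoE layer are divided evenly across the expert-parallel ranks. A minimal sketch (the helper name is hypothetical; 40 experts matches the granite checkpoint shape shown later in this PR):

```python
# Hypothetical sketch of expert partitioning under expert parallelism.
# granite-3.0-3b-a800m has 40 experts per MoE layer (cf. the [40, 1536, 512]
# checkpoint shape in the resumption section of this PR).

def experts_per_rank(num_experts: int, ep_degree: int) -> int:
    """Each of the ep_degree expert-parallel ranks holds an equal slice of experts."""
    assert num_experts % ep_degree == 0, "expert count must divide evenly"
    return num_experts // ep_degree

print(experts_per_rank(40, 1))  # ep1: each rank holds all 40 experts
print(experts_per_rank(40, 2))  # ep2: 20 experts per rank
```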
| model_name_or_path | num_gpus | framework_config | mem_nvidia_mem_reserved | train_runtime | mem util | speedup |
| --- | --- | --- | --- | --- | --- | --- |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | none | 71199 | 2371.93 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1 | 71187 | 742.739 | 1.0 | 3.19 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1-padding-free | 48401 | 631.976 | 0.68 | 3.75 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1-padding-free-foak | 42651 | 615.453 | 0.6 | 3.85 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | none | 46829 | 1355.71 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1 | 52503 | 485.51 | 1.12 | 2.79 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1-padding-free | 42452 | 454.344 | 0.91 | 2.98 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1-padding-free-foak | 37743 | 433.481 | 0.81 | 3.13 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2 | 40193 | 577.216 | 0.86 | 2.35 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2-padding-free | 31012 | 546.507 | 0.66 | 2.48 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2-padding-free-foak | 26075 | 524.775 | 0.56 | 2.58 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | none | 37996 | 708.391 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1 | 51145 | 262.957 | 1.35 | 2.69 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1-padding-free | 38560 | 241.297 | 1.01 | 2.94 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1-padding-free-foak | 35153 | 232.043 | 0.93 | 3.05 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2 | 40878.5 | 300.285 | 1.08 | 2.36 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2-padding-free | 28133 | 283.544 | 0.74 | 2.5 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2-padding-free-foak | 24665.5 | 274.126 | 0.65 | 2.58 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4 | 31777.5 | 307.126 | 0.84 | 2.31 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4-padding-free | 21585.5 | 284.608 | 0.57 | 2.49 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4-padding-free-foak | 18368 | 278.125 | 0.48 | 2.55 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | none | 65607.2 | 4180.95 | baseline | baseline |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | moe-scattermoe-granite-ep8 | 52004.8 | 1071.2 | 0.79 | 3.9 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | moe-scattermoe-granite-ep8-foak | 51961.2 | 1043.67 | 0.79 | 4.01 |
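For reference, the derived columns can be reproduced from the raw ones: speedup is the baseline train_runtime divided by the config's runtime, and mem util is reserved memory relative to the baseline. A quick check against the 1-GPU granite rows:

```python
# Reproduce the derived columns from the raw table values (1-GPU granite rows).
baseline_runtime, baseline_mem = 2371.93, 71199   # framework_config = none
runtime, mem = 742.739, 71187                     # moe-scattermoe-granite-ep1

speedup = baseline_runtime / runtime   # higher is better
mem_util = mem / baseline_mem          # fraction of baseline reserved memory

print(round(speedup, 2), round(mem_util, 2))  # 3.19 1.0
```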

Resumption

Non-sharded checkpoints: tested resumption on 2 devices for expert-parallel degree 1 and 2.

```python
import torch.distributed.checkpoint as dcp

reader = dcp.FileSystemReader("tmp3/checkpoint-10/pytorch_model_fsdp_0")
metadata = reader.read_metadata()
metadata.state_dict_metadata['model.model.layers.1.block_sparse_moe.w1.weight']
```

which gives:

```
TensorStorageMetadata(properties=TensorProperties(dtype=torch.bfloat16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False), size=torch.Size([40, 1536, 512]), chunks=[ChunkStorageMetadata(offsets=torch.Size([0, 0, 0]), sizes=torch.Size([20, 1536, 512])), ChunkStorageMetadata(offsets=torch.Size([20, 0, 0]), sizes=torch.Size([20, 1536, 512]))])
```
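The chunk metadata above can be sanity-checked: the per-rank shards should tile the full expert dimension (dim 0) without gaps or overlap. A pure-Python sketch (the helper is hypothetical, not part of the plugin):

```python
# Hypothetical sanity check: do the chunk (offset, size) pairs along dim 0
# exactly cover [0, full_size) with no gaps or overlap?
def chunks_tile_dim0(full_size, chunks):
    """chunks: list of (offset, size) along dim 0."""
    covered = 0
    for offset, size in sorted(chunks):
        if offset != covered:  # gap or overlap detected
            return False
        covered += size
    return covered == full_size

# The 2-device metadata above: two 20-expert shards covering all 40 experts.
print(chunks_tile_dim0(40, [(0, 20), (20, 20)]))  # True
```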


Also tested resumption for sharded checkpoints (Mixtral):

```python
import torch.distributed.checkpoint as dcp

reader = dcp.FileSystemReader("tmp3/checkpoint-10/pytorch_model_fsdp_0")
metadata = reader.read_metadata()
metadata.state_dict_metadata['model.model.layers.1.block_sparse_moe.w1.weight']
```

which gives:

```
TensorStorageMetadata(properties=TensorProperties(dtype=torch.bfloat16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False), size=torch.Size([8, 4096, 14336]), chunks=[ChunkStorageMetadata(offsets=torch.Size([0, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([1, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([2, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([3, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([4, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([5, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([6, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([7, 0, 0]), sizes=torch.Size([1, 4096, 14336]))])
```

Handling the State Dict

We provide a convenience function restore_scattermoe_checkpoint_to_orig to load the DCP checkpoint and, optionally, convert it back to the original format if the pretrained checkpoint is provided.

```python
from fms_acceleration_moe.utils.checkpoint_utils import restore_scattermoe_checkpoint_to_orig
from fms_acceleration_moe.utils import prepare_scattermoe
from transformers import AutoModelForCausalLM

MODEL = 'ibm-granite/granite-3.0-3b-a800m-instruct'
CKPT = "tmp2/checkpoint-50/pytorch_model_fsdp_0"

# load the model and convert it to scattermoe
model = AutoModelForCausalLM.from_pretrained(MODEL, device_map='cuda')
prepare_scattermoe(
    model,
    checkpoint_name_or_path=MODEL,
    rank=0,
    world_size=1,
    ep_degree=1,
    mixed_precision=False,  # currently this is hardcoded to OFF
)

# load the DCP checkpoint into the scattermoe model
sd = restore_scattermoe_checkpoint_to_orig(CKPT)
model.load_state_dict(sd)

# load the original model
model2 = AutoModelForCausalLM.from_pretrained(MODEL, device_map='cuda')
# use the utility to convert the DCP checkpoint back to the original format
sd = restore_scattermoe_checkpoint_to_orig(CKPT, pretrained_model_name_or_path=MODEL)
model2.load_state_dict(sd)
```
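As a sanity check after both loading paths, one might verify the restored state dicts agree key-by-key. A minimal pure-Python sketch (stand-in values, not real tensors; with torch you would compare entries via torch.equal):

```python
# Hypothetical check that two state dicts hold identical weights.
# Lists stand in for tensors here to keep the sketch dependency-free.
def state_dicts_match(sd_a, sd_b):
    """True iff both dicts have the same keys and equal values per key."""
    if sd_a.keys() != sd_b.keys():
        return False
    return all(sd_a[k] == sd_b[k] for k in sd_a)

print(state_dicts_match({"w1": [1, 2]}, {"w1": [1, 2]}))  # True
print(state_dicts_match({"w1": [1, 2]}, {"w2": [1, 2]}))  # False
```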


Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
@fabianlim fabianlim marked this pull request as draft November 3, 2024 07:46
@fabianlim fabianlim force-pushed the refactor/moe branch 3 times, most recently from f412899 to bccd967 Compare November 5, 2024 01:52
@fabianlim fabianlim marked this pull request as ready for review November 6, 2024 10:30
@fabianlim fabianlim requested a review from anhuong November 7, 2024 02:28
@fabianlim fabianlim merged commit 5b35eae into main Nov 13, 2024
7 checks passed
@fabianlim fabianlim deleted the refactor/moe branch November 13, 2024 02:38