From d9099357b0b284f65b727da05a9af8e7a71ef07e Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 15:50:49 -0700
Subject: [PATCH] Enable 2D sharding (#17)

Summary:
This pull request adds 2D SPMD sharding to the existing sharding options.
It shards both weights and activations.

Here is the sharding strategy, assuming a 2D mesh (data, model) with
data x model == num_devices (see the sketch at the end of this summary):
1. input        (data, None, model)
2. embedding    (model, data)
3. attn QKV     (data, model)
4. attn O       (model, data)
5. mlp gate, up (model, data)
6. mlp down     (data, model)
7. activation   (data, None, model)

Currently you specify the model dimension with the new option
--spmd_2d_sharding; the data dimension is then derived as
num_devices // model.

TODO: maybe add another option to control whether the activations/inputs
are sharded at all, or sharded differently.
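For reference, below is a minimal standalone sketch of the mapping above,
written against torch_xla's experimental xla_sharding API. It is
illustrative only and not part of the patch: the tensor shapes, the
data=2 x model=2 split, and the use of xs.Mesh (instead of the
xs.HybridMesh used in the patch) are assumptions made for the example,
and SPMD mode is assumed to be enabled (e.g. XLA_USE_SPMD=1).

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs
    import torch_xla.runtime as xr

    num_devices = xr.global_runtime_device_count()
    model_dim = 2                        # what --spmd_2d_sharding would specify
    data_dim = num_devices // model_dim  # data dimension is derived
    assert data_dim * model_dim == num_devices

    device_ids = np.arange(num_devices)
    # 2D meshes for weights: (data, model) and its transpose (model, data).
    data_model_mesh = xs.Mesh(device_ids, (data_dim, model_dim))
    model_data_mesh = xs.Mesh(device_ids, (model_dim, data_dim))

    # An attention QKV-like weight is sharded (data, model) ...
    qkv_weight = torch.randn(4096, 4096).to(xm.xla_device())
    xs.mark_sharding(qkv_weight, data_model_mesh, (0, 1))
    # ... while an embedding-like weight is sharded (model, data).
    embed_weight = torch.randn(32000, 4096).to(xm.xla_device())
    xs.mark_sharding(embed_weight, model_data_mesh, (0, 1))

    # Inputs/activations of shape (batch, seq, hidden) use a 3D mesh so the
    # sequence dimension stays unsharded: (data, None, model).
    act_mesh = xs.Mesh(device_ids, (data_dim, 1, model_dim))
    activation = torch.randn(8, 2048, 4096).to(xm.xla_device())
    xs.mark_sharding(activation, act_mesh, (0, 1, 2))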
---
 examples/pytorch/language-modeling/run_clm.py | 47 +++++++++++++++++++
 .../models/llama/modeling_llama.py            | 37 +++++++++++++++
 src/transformers/trainer.py                   |  7 +--
 3 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index ee8d2c7348f3ef..af73b0e7fac3de 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -189,6 +189,14 @@ class ModelArguments:
             )
         },
     )
+    spmd_2d_sharding: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD 2D sharding, i.e., shard both weights and activations; the value specifies the model dimension"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -297,6 +305,7 @@ def main():
     training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
     training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
     training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
+    training_args.spmd_2d_sharding = model_args.spmd_2d_sharding
 
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -469,6 +478,8 @@ def main():
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
 
+    # Pass the 2D sharding config to the actual model.
+    config.spmd_2d_sharding = model_args.spmd_2d_sharding
     if model_args.model_name_or_path:
         torch_dtype = (
             model_args.torch_dtype
@@ -539,6 +550,42 @@ def main():
             else:
                 assert len(param.shape) == 2
                 xs.mark_sharding(param, mesh, range(len(param.shape)))
+    elif model_args.spmd_2d_sharding > 0:
+        print('Applying 2D sharding to all parameters')
+        for name, param in model.named_parameters():
+            # Apply 2D sharding:
+            #   embedding    (model, data)
+            #   attn QKV     (data, model)
+            #   attn O       (model, data)
+            #   mlp gate, up (model, data)
+            #   mlp down     (data, model)
+            print('> Sharding tensor', name, param.shape)
+            mod = model_args.spmd_2d_sharding
+            data = num_devices // mod
+            assert mod * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
+            model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))
+
+            # Skip 1D tensors: we don't shard layernorm weights, and
+            # LLaMA doesn't use biases.
+            if len(param.shape) == 1:
+                continue
+
+            if 'embed_tokens' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'o_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'gate_proj' in name or 'up_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'down_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'lm_head' in name:  # lm_head is the output projection; it has the same shape as embed_tokens
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+
+            import torch_xla
+            print(torch_xla._XLAC._get_xla_sharding_spec(param))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4c31729337ddd1..f387d1b2266e2b 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -392,6 +392,8 @@ class LlamaAttention(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -540,6 +542,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
+        # Apply 2D sharding:
+        #   activation (data, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding activations', attn_output.shape)
+        model = self.spmd_2d_sharding
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+
         return attn_output, attn_weights, past_key_value
 
 
@@ -935,6 +953,9 @@ class LlamaModel(LlamaPreTrainedModel):
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
@@ -1015,7 +1036,23 @@ def forward(
         )
 
         # embed positions
+        # hidden_states (= inputs_embeds) is the input to the decoder layers.
         hidden_states = inputs_embeds
+        # Apply 2D sharding:
+        #   input (data, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding hidden_states', hidden_states.shape)
+        model = self.spmd_2d_sharding
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bd0d23d18da38d..5738abea058832 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1427,10 +1427,11 @@ def _xla_sharded_dataloader(self, dataloader):
             if self.args.spmd_batch_sharding:
                 mesh = xs.Mesh(device_ids, (num_devices, 1))
                 sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-            elif self.args.spmd_tensor_sharding > 0:
-                tensor = self.args.spmd_tensor_sharding
+            elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+                assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+                tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
                 fsdp = num_devices // tensor
-                mesh = xs.Mesh(device_ids, (fsdp, tensor))
+                mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
                 partition_spec = (0, None)
                 sharding_spec = xs.ShardingSpec(mesh, partition_spec)