Enable 2D sharding (huggingface#17)
Summary:
This pull request adds 2D SPMD sharding. It shards both the weights and the activations. Here is the sharding strategy.

Let's say we have a 2D mesh (data, model) and data x model == num_devices:
1. input (data, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data, None, model)
Currently you can specify the model dimension with the new option --spmd_2d_sharding; the data dimension is then calculated automatically (see the sketch below).

TODO: maybe we should add another option to control whether the activations/inputs are sharded at all, or sharded differently.
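
For reference, here is a minimal sketch of how this strategy maps onto torch_xla's SPMD API (the calls mirror the code added below; the model-axis value and the tensor shapes are illustrative assumptions, not part of the commit):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

model_axis = 4  # hypothetical --spmd_2d_sharding value
num_devices = xr.global_runtime_device_count()
data_axis = num_devices // model_axis
assert data_axis * model_axis == num_devices

# Weight meshes: (data, model) and its transpose (model, data).
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data_axis, model_axis))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(model_axis, data_axis))

# e.g. an attn QKV weight is sharded (data, model); embedding, attn O,
# and mlp gate/up would use model_data_mesh instead.
weight = torch.zeros(4096, 4096, device=xm.xla_device())
xs.mark_sharding(weight, data_model_mesh, (0, 1))

# Inputs/activations of shape (batch, seq, hidden) are sharded
# (data, None, model) via a 3D mesh with a singleton middle axis.
activation = torch.zeros(8, 2048, 4096, device=xm.xla_device())
act_mesh = xs.HybridMesh(ici_mesh_shape=(data_axis, 1, model_axis))
xs.mark_sharding(activation, act_mesh, (0, 1, 2))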
alanwaketan committed Oct 27, 2023
1 parent 674ab35 commit d909935
Showing 3 changed files with 88 additions and 3 deletions.
47 changes: 47 additions & 0 deletions examples/pytorch/language-modeling/run_clm.py
@@ -189,6 +189,14 @@ class ModelArguments:
            )
        },
    )
    spmd_2d_sharding: int = field(
        default=0,
        metadata={
            "help": (
"Will apply XLA SPMD to 2D sharding, i.e., weights + activations, and spmd_2d_sharding specifies the model dimension"
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -297,6 +305,7 @@ def main():
    training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
    training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
    training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
    training_args.spmd_2d_sharding = model_args.spmd_2d_sharding

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -469,6 +478,8 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

    # Pass the 2d sharding config to the actual model.
    config.spmd_2d_sharding = model_args.spmd_2d_sharding
    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
@@ -539,6 +550,42 @@ def main():
            else:
                assert len(param.shape) == 2
                xs.mark_sharding(param, mesh, range(len(param.shape)))
    elif model_args.spmd_2d_sharding > 0:
        print('Applying 2D sharding to all parameters')
        for name, param in model.named_parameters():
            # Apply 2D sharding:
            # embedding (model, data)
            # attn QKV (data, model)
            # attn O (model, data)
            # mlp gate, up (model, data)
            # mlp down (data, model)
            print('> Sharding tensor', name, param.shape)
            mod = model_args.spmd_2d_sharding
            data = num_devices // mod
            assert mod * data == num_devices
            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
            model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))

            # We don't care about layernorm's weights, and
            # LLaMA doesn't use biases.
            if len(param.shape) == 1:
                continue

            if 'embed_tokens' in name:
                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
            elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
            elif 'o_proj' in name:
                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
            elif 'gate_proj' in name or 'up_proj' in name:
                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
            elif 'down_proj' in name:
                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
            elif 'lm_head' in name:  # lm_head has the same shape as embed_tokens, so shard it like the embedding
                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))

            import torch_xla
            print(torch_xla._XLAC._get_xla_sharding_spec(param))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
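
The if/elif chain above keys the mesh choice off substrings of the parameter name. An equivalent, table-driven sketch of the same dispatch; this is illustration only, and the helper name mesh_for_param is hypothetical (not part of the commit):

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

def mesh_for_param(name, model_axis):
    """Return the 2D mesh for a parameter, or None to leave it replicated."""
    num_devices = xr.global_runtime_device_count()
    data_axis = num_devices // model_axis
    data_model = xs.HybridMesh(ici_mesh_shape=(data_axis, model_axis))
    model_data = xs.HybridMesh(ici_mesh_shape=(model_axis, data_axis))
    # Substring -> mesh, mirroring the branches in run_clm.py above.
    rules = {
        'embed_tokens': model_data,
        'q_proj': data_model, 'k_proj': data_model, 'v_proj': data_model,
        'o_proj': model_data,
        'gate_proj': model_data, 'up_proj': model_data,
        'down_proj': data_model,
        'lm_head': model_data,
    }
    for key, mesh in rules.items():
        if key in name:
            return mesh
    return None  # e.g. layernorm weights stay replicated

# Usage inside the parameter loop:
#   mesh = mesh_for_param(name, model_args.spmd_2d_sharding)
#   if mesh is not None and len(param.shape) == 2:
#       xs.mark_sharding(param, mesh, (0, 1))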
37 changes: 37 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -392,6 +392,8 @@ class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        # For PyTorch/XLA's SPMD 2D sharding
        self.spmd_2d_sharding = config.spmd_2d_sharding
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
@@ -540,6 +542,22 @@ def forward(
        if not output_attentions:
            attn_weights = None

        # Apply 2D sharding:
        # activation (data, None, model)
        import torch_xla.core.xla_model as xm
        import torch_xla.experimental.xla_sharding as xs
        import torch_xla.runtime as xr
        import torch_xla
        num_devices = xr.global_runtime_device_count()
        device_ids = torch.arange(num_devices)
        print('> Sharding activations', attn_output.shape)
        model = self.spmd_2d_sharding
        data = num_devices // model
        assert model * data == num_devices
        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))

        return attn_output, attn_weights, past_key_value


@@ -935,6 +953,9 @@ class LlamaModel(LlamaPreTrainedModel):

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # For PyTorch/XLA's SPMD 2D sharding
        self.spmd_2d_sharding = config.spmd_2d_sharding

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

@@ -1015,7 +1036,23 @@ def forward(
        )

        # embed positions
        # hidden_states starts out as the embedded inputs, i.e. the input to the decoder layers
        hidden_states = inputs_embeds
        # Apply 2D sharding:
        # input (data, None, model)
        import torch_xla.core.xla_model as xm
        import torch_xla.experimental.xla_sharding as xs
        import torch_xla.runtime as xr
        import torch_xla
        num_devices = xr.global_runtime_device_count()
        device_ids = torch.arange(num_devices)
        print('> Sharding hidden_states', hidden_states.shape)
        model = self.spmd_2d_sharding
        data = num_devices // model
        assert model * data == num_devices
        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))

        if self.gradient_checkpointing and self.training:
            if use_cache:
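
Note that the forward-pass changes above rebuild the activation mesh (and re-import torch_xla) on every call, and apply the sharding regardless of whether spmd_2d_sharding is set. A hedged sketch of an alternative that caches the mesh once and guards on the flag; the module and attribute names here are illustrative, not part of the commit:

import torch
import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

class Activation2DSharder(nn.Module):
    """Marks (batch, seq, hidden) activations as (data, None, model) on a cached mesh."""

    def __init__(self, model_axis: int):
        super().__init__()
        self.act_mesh = None
        if model_axis > 0:
            num_devices = xr.global_runtime_device_count()
            data_axis = num_devices // model_axis
            assert data_axis * model_axis == num_devices
            self.act_mesh = xs.HybridMesh(ici_mesh_shape=(data_axis, 1, model_axis))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is already an XLA tensor of shape (batch, seq, hidden).
        if self.act_mesh is not None:
            xs.mark_sharding(x, self.act_mesh, (0, 1, 2))
        return x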
7 changes: 4 additions & 3 deletions src/transformers/trainer.py
@@ -1427,10 +1427,11 @@ def _xla_sharded_dataloader(self, dataloader):
         if self.args.spmd_batch_sharding:
             mesh = xs.Mesh(device_ids, (num_devices, 1))
             sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-        elif self.args.spmd_tensor_sharding > 0:
-            tensor = self.args.spmd_tensor_sharding
+        elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+            assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+            tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
             fsdp = num_devices // tensor
-            mesh = xs.Mesh(device_ids, (fsdp, tensor))
+            mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
             partition_spec = (0, None)
             sharding_spec = xs.ShardingSpec(mesh, partition_spec)

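
For context, the ShardingSpec built here tells the input pipeline how to lay batches out across the mesh. A hedged sketch of how such a spec is typically consumed, assuming a torch_xla version whose MpDeviceLoader accepts an input_sharding argument (as in the PyTorch/XLA SPMD examples); the helper name wrap_loader is illustrative:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

def wrap_loader(dataloader, sharding_spec: xs.ShardingSpec):
    # Each batch from `dataloader` is transferred to the XLA device and
    # annotated with `sharding_spec` before the training step sees it.
    return pl.MpDeviceLoader(dataloader, xm.xla_device(), input_sharding=sharding_spec)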
