PEFT + TP support (#620)
michaelbenayoun authored Jun 14, 2024
1 parent b86b48f commit 096c964
Showing 20 changed files with 1,261 additions and 463 deletions.
5 changes: 1 addition & 4 deletions optimum/neuron/accelerate/accelerator.py
@@ -48,6 +48,7 @@
replace_class_in_inheritance_hierarchy,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
from ..utils.model_utils import get_tied_parameters_dict, tie_parameters
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
from ..utils.torch_xla_and_neuronx_initialization import check_neuron_cc_flags_for_model
from .optimizer import NeuronAcceleratedOptimizer
@@ -57,9 +58,7 @@
AutocastBackend,
ModelParallelismPlugin,
NeuronDistributedType,
get_tied_parameters_dict,
patch_accelerate_is_torch_xla_available,
tie_parameters,
)
from .utils.misc import (
apply_activation_checkpointing,
@@ -483,8 +482,6 @@ def prepare_model(
module._use_flash_attention_2 = False

if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
if isinstance(model, NeuronPeftModel):
raise NotImplementedError("PEFT is not supported with model parallelism for now.")
model = self._prepare_model_for_mp(
model, device_placement=device_placement, evaluation_mode=evaluation_mode
)
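Note on this file's changes: the tied-parameter helpers are now imported from optimum/neuron/utils/model_utils.py instead of being defined in the accelerate utilities, and the guard that raised NotImplementedError for PEFT models under model parallelism is gone, so a NeuronPeftModel is handed to _prepare_model_for_mp like any other model. A rough sketch of the flow this unblocks follows; the ModelParallelismPlugin and NeuronAccelerator constructor arguments are assumptions for illustration (in practice optimum-neuron wraps PEFT models in its own NeuronPeftModel, so the plain peft call here is a simplification).

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from optimum.neuron.accelerate import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismPlugin

# Wrap a base model with a LoRA adapter, then let the accelerator shard it.
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16))

# Hypothetical plugin configuration; the exact arguments may differ.
mp_plugin = ModelParallelismPlugin(tensor_parallel_size=2)
accelerator = NeuronAccelerator(mp_plugin=mp_plugin)

# Before this commit, prepare() raised NotImplementedError for PEFT models
# under MODEL_PARALLELISM; they now go through _prepare_model_for_mp.
model = accelerator.prepare(peft_model)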
2 changes: 1 addition & 1 deletion optimum/neuron/accelerate/utils/__init__.py
@@ -18,4 +18,4 @@
ModelParallelismPlugin,
NeuronDistributedType,
)
from .misc import get_tied_parameters_dict, patch_accelerate_is_torch_xla_available, tie_parameters
from .misc import patch_accelerate_is_torch_xla_available
46 changes: 0 additions & 46 deletions optimum/neuron/accelerate/utils/misc.py
@@ -23,7 +23,6 @@
from transformers.modeling_utils import get_parameter_dtype

from ....utils import logging
from ...distributed.utils import named_parameters
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
from ...utils.patching import Patcher
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
@@ -151,51 +150,6 @@ def wrapper(*args, **kwargs):
return wrapper.__get__(orig_self)


@requires_neuronx_distributed
def get_tied_parameters_dict(model: Union["torch.nn.Module", "NxDPPModel"]) -> Dict[str, str]:
from neuronx_distributed.pipeline import NxDPPModel

unique_parameters = {}
tied_parameters = {}
if isinstance(model, NxDPPModel):
module = model.local_module
else:
module = model
for name, param in named_parameters(module, remove_duplicate=False):
if param in unique_parameters:
tied_parameter_name = unique_parameters[param]
tied_parameters[name] = tied_parameter_name
else:
unique_parameters[param] = name
return tied_parameters


@requires_neuronx_distributed
def tie_parameters(model: Union["torch.nn.Module", "NxDPPModel"], tied_parameters_dict: Dict[str, str]):
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
module = model.local_module
else:
module = model

for param_to_tie_name, param_name in tied_parameters_dict.items():
param_to_tie_name = param_to_tie_name.rsplit(".", maxsplit=1)

param_to_tie_parent_module = (
module if len(param_to_tie_name) == 1 else module.get_submodule(param_to_tie_name[0])
)
param_to_tie = getattr(param_to_tie_parent_module, param_to_tie_name[1])

param_name = param_name.rsplit(".", maxsplit=1)
parent_module = module if len(param_name) == 1 else module.get_submodule(param_name[0])
param = getattr(parent_module, param_name[1])

if param_to_tie is not param:
del param_to_tie
setattr(param_to_tie_parent_module, param_to_tie_name[1], param)


# TODO: @michaelbenayoun
# Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`.
@requires_torch_xla
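The two helpers deleted here are not dropped from the code base: the new import at the top of accelerator.py (from ..utils.model_utils import get_tied_parameters_dict, tie_parameters) indicates they now live in optimum/neuron/utils/model_utils.py. Reconstructed from the deleted bodies above, get_tied_parameters_dict maps every duplicate parameter name to the first name that owns the underlying tensor, and tie_parameters re-assigns the attribute so that the sharing holds again. A minimal sketch of how the pair is meant to be used, with a toy module invented for illustration (both helpers were decorated with requires_neuronx_distributed, so this assumes a Neuron environment):

import torch

from optimum.neuron.utils.model_utils import get_tied_parameters_dict, tie_parameters


class TiedLM(torch.nn.Module):
    """Toy model whose output projection shares its weight with the embedding."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)
        self.lm_head = torch.nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights


model = TiedLM()

# Duplicates point at the canonical name: {"lm_head.weight": "embed.weight"}.
tied = get_tied_parameters_dict(model)

# If an operation breaks the sharing (here: re-creating the head), the recorded
# mapping lets tie_parameters restore it.
model.lm_head = torch.nn.Linear(4, 10, bias=False)
tie_parameters(model, tied)
assert model.lm_head.weight is model.embed.weight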
202 changes: 122 additions & 80 deletions optimum/neuron/distributed/base.py

Large diffs are not rendered by default.

54 changes: 41 additions & 13 deletions optimum/neuron/distributed/checkpointing.py
@@ -20,12 +20,30 @@

import torch
from transformers.modeling_utils import shard_checkpoint
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_peft_available,
)

from ..utils.peft_utils import ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME
from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
from .utils import MODEL_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank


if is_peft_available():
from peft.utils.constants import (
SAFETENSORS_WEIGHTS_NAME as PEFT_SAFETENSORS_WEIGHTS_NAME,
)
from peft.utils.constants import (
WEIGHTS_NAME as PEFT_WEIGHTS_NAME,
)
else:
PEFT_SAFETENSORS_WEIGHTS_NAME = PEFT_WEIGHTS_NAME = ""


def create_gqa_query_or_output_projection_weight_from_full_weight(
full_weight: torch.Tensor,
tp_size: int,
@@ -129,18 +147,9 @@ def consolidate_tensor_parallel_checkpoints(


@requires_neuronx_distributed
def consolidate_model_parallel_checkpoints(checkpoint_dir: Union[str, Path]) -> Dict[str, "torch.Tensor"]:
def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
from neuronx_distributed.parallel_layers.checkpointing import _xser_load

if not isinstance(checkpoint_dir, Path):
checkpoint_dir = Path(checkpoint_dir)

if checkpoint_dir.name != MODEL_PARALLEL_SHARDS_DIR_NAME:
if (checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
checkpoint_dir = checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME
else:
raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}")

model_checkpoint_dir = checkpoint_dir / "model"

# Case 1: the checkpoint was saved with xser.
@@ -191,14 +200,33 @@ def consolidate_model_parallel_checkpoints_to_unified_checkpoint(
):
from safetensors.torch import save_file

if not isinstance(checkpoint_dir, Path):
checkpoint_dir = Path(checkpoint_dir)

if checkpoint_dir.name not in [MODEL_PARALLEL_SHARDS_DIR_NAME, ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME]:
if (checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
checkpoint_dir = checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME
elif (checkpoint_dir / ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
checkpoint_dir = checkpoint_dir / ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME
else:
raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}")

if not isinstance(output_dir, Path):
output_dir = Path(output_dir)

is_adapter_model = checkpoint_dir.name == ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME
if is_adapter_model:
safe_weights_name = PEFT_SAFETENSORS_WEIGHTS_NAME
weights_name = PEFT_WEIGHTS_NAME
else:
safe_weights_name = SAFE_WEIGHTS_NAME
weights_name = WEIGHTS_NAME

output_dir.mkdir(parents=True, exist_ok=True)

state_dict = consolidate_model_parallel_checkpoints(checkpoint_dir)
shards, index = shard_checkpoint(
state_dict, weights_name=SAFE_WEIGHTS_NAME if save_format == "safetensors" else WEIGHTS_NAME
state_dict, weights_name=safe_weights_name if save_format == "safetensors" else weights_name
)
for shard_file, shard in shards.items():
if save_format == "safetensors":
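Summary of this file's changes: the logic that locates the shards directory moved out of consolidate_model_parallel_checkpoints and into consolidate_model_parallel_checkpoints_to_unified_checkpoint, which now also recognizes an adapter shards directory (ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME) and, in that case, writes the consolidated weights under the PEFT file names (adapter_model.safetensors / adapter_model.bin) instead of the transformers ones. A hedged usage sketch; the directory layout is an assumption for illustration, and running it requires neuronx_distributed to be installed:

from pathlib import Path

from optimum.neuron.distributed.checkpointing import (
    consolidate_model_parallel_checkpoints_to_unified_checkpoint,
)

# A checkpoint directory produced by a tensor-parallel (PEFT or full) training
# run; the function looks for the model or adapter shards sub-directory itself.
consolidate_model_parallel_checkpoints_to_unified_checkpoint(
    checkpoint_dir=Path("output/checkpoint-500"),
    output_dir=Path("output/consolidated"),
    save_format="safetensors",
)
# For adapter shards this writes PEFT-named files (adapter_model.safetensors);
# for full-model shards it falls back to the transformers naming.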
15 changes: 13 additions & 2 deletions optimum/neuron/distributed/parallel_layers.py
@@ -23,6 +23,7 @@

import torch
from torch.nn.modules.loss import _WeightedLoss
from transformers.utils import is_peft_available

from ...utils import NormalizedConfigManager, logging
from ..utils import patch_everywhere, patch_within_function
@@ -246,12 +247,22 @@ def _transform(
)

embedding_layer = layer.get_submodule(embedding_name)
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

if isinstance(embedding_layer, BaseTunerLayer):
num_embeddings = embedding_layer.get_base_layer().num_embeddings
else:
num_embeddings = embedding_layer.num_embeddings
else:
num_embeddings = embedding_layer.num_embeddings

tp_size = parallel_state.get_tensor_model_parallel_size()
if embedding_layer.num_embeddings % tp_size != 0:
if num_embeddings % tp_size != 0:
if is_main_worker():
logger.warning(
f"Embedding parallelization for TP was skipped because the tensor parallel size ({tp_size}) does not "
f"divide the number of embeddings ({embedding_layer.num_embeddings})"
f"divide the number of embeddings ({num_embeddings})"
)
return layer

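The embedding-parallelization check now resolves num_embeddings through the PEFT wrapper when one is present: a LoRA-wrapped nn.Embedding is a BaseTunerLayer that keeps the real embedding behind get_base_layer(), so the attribute may not be available on the wrapper itself. A standalone illustration of the same lookup outside optimum-neuron; the tiny test model and LoRA config are assumptions used only to produce a wrapped embedding:

from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
# Target the token embedding so that it gets wrapped by a LoRA tuner layer.
peft_model = get_peft_model(model, LoraConfig(r=4, target_modules=["wte"]))

embedding_layer = peft_model.get_base_model().transformer.wte

# Same branching as in the diff above: read num_embeddings from the base layer
# when the module is a tuner wrapper, from the module itself otherwise.
if isinstance(embedding_layer, BaseTunerLayer):
    num_embeddings = embedding_layer.get_base_layer().num_embeddings
else:
    num_embeddings = embedding_layer.num_embeddings
print(num_embeddings)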
14 changes: 11 additions & 3 deletions optimum/neuron/distributed/parallelizers_manager.py
@@ -19,6 +19,7 @@

from transformers import PreTrainedModel

from ..utils.peft_utils import NeuronPeftModel
from ..utils.require_utils import requires_neuronx_distributed
from .base import Parallelizer

@@ -71,19 +72,24 @@ def get_supported_model_types(cls) -> List[str]:

@classmethod
@requires_neuronx_distributed
def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel]) -> str:
def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel, NeuronPeftModel]) -> str:
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model_type_or_model, NxDPPModel):
model_type_or_model = model_type_or_model.original_torch_module
elif isinstance(model_type_or_model, NeuronPeftModel):
model_type_or_model = model_type_or_model.get_base_model()

if isinstance(model_type_or_model, PreTrainedModel):
model_type = model_type_or_model.config.model_type
else:
model_type = model_type_or_model
return model_type

@classmethod
def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Tuple[bool, bool, bool]:
def is_model_supported(
cls, model_type_or_model: Union[str, PreTrainedModel, NeuronPeftModel]
) -> Tuple[bool, bool, bool]:
"""
Returns a tuple of 3 booleans where:
- The first element indicates if tensor parallelism can be used for this model,
@@ -106,7 +112,9 @@ def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) ->
return (for_tp, for_sp, for_pp)

@classmethod
def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Type[Parallelizer]:
def parallelizer_for_model(
cls, model_type_or_model: Union[str, PreTrainedModel, NeuronPeftModel]
) -> Type[Parallelizer]:
"""
Returns the parallelizer class associated to the model.
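With these signature changes, the manager class defined in this module (ParallelizersManager) accepts a NeuronPeftModel wherever it previously accepted a PreTrainedModel or a model-type string, unwrapping it with get_base_model() before reading config.model_type. A short usage sketch; "llama" is only an example of a supported architecture:

from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager

# Accepts a model-type string, a PreTrainedModel, or (after this commit) a
# NeuronPeftModel, which is unwrapped internally with get_base_model().
for_tp, for_sp, for_pp = ParallelizersManager.is_model_supported("llama")
print(for_tp, for_sp, for_pp)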