Commit

add meta onDevice support for LLAMA2 (#4147)
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
dc3671 and molly-smith authored Aug 24, 2023
1 parent f690319 commit 0712e29
Showing 2 changed files with 7 additions and 11 deletions.
deepspeed/module_inject/auto_tp.py (8 changes: 6 additions & 2 deletions)
@@ -114,6 +114,11 @@ def copy(self, dst, src, int8=False, allocate_tensor=False):
 
 class Loading():
 
+    def is_load_module(module):
+        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
+        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
+        return module.__class__ in load_layers or module._get_name() in load_layer_names
+
     def load_buffer(module, state_dict, prefix):
         for name in module._buffers.keys():
             if module._buffers[name].data.is_meta:
@@ -399,8 +404,7 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
             else:
                 class_name = prev_class_name + '.' + prev_name
             checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
-            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
+            if Loading.is_load_module(child) and self.state_dict is not None:
                 if any(checking_key in item for item in self.state_dict):
                     Loading.load(child, self.state_dict, checking_key, self.mp_group)
                 else:
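The new Loading.is_load_module helper centralizes the allow-list check that previously lived inline at each call site: a module's weights are loaded out of the checkpoint state_dict (replacing its meta tensors) only if its class, or its class name, is on the list, which now also covers LlamaRMSNorm (for LLAMA2) and OPTLearnedPositionalEmbedding. A minimal sketch of the check; the LlamaRMSNorm class below is a stand-in defined purely for illustration:

import torch.nn as nn
from deepspeed.module_inject.auto_tp import Loading

class LlamaRMSNorm(nn.Module):  # stand-in for transformers' LlamaRMSNorm, illustration only
    pass

print(Loading.is_load_module(nn.Linear(4, 4)))  # True  -> matched by class, loaded from the state_dict
print(Loading.is_load_module(LlamaRMSNorm()))   # True  -> matched by name via module._get_name()
print(Loading.is_load_module(nn.Dropout(0.1)))  # False -> skipped by the meta-loading path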
deepspeed/module_inject/replace_module.py (10 changes: 1 addition & 9 deletions)
@@ -16,7 +16,6 @@
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
 
 from deepspeed import comm as dist
-from torch import nn
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -595,12 +594,6 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
     Returns:
         Modified ``model``.
     """
-    try:
-        import transformers
-        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
-    except:
-        OPTLearnedPositionalEmbedding = None
-    load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm, OPTLearnedPositionalEmbedding]
     for name, child in model.named_children():
         if child.__class__ in policies:
             replaced_module = policies[child.__class__][0](child,
@@ -616,8 +609,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
             layer_id += 1
         else:
             checking_key = prefix + name + '.'
-            if (child.__class__ in load_layers
-                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and state_dict is not None:
+            if Loading.is_load_module(child) and state_dict is not None:
                 if any(checking_key in item for item in state_dict):
                     Loading.load(
                         child,
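For context, this check is what the meta-tensor ("OnDevice") initialization path relies on: the model is first built with empty meta tensors, and AutoTP then materializes only the modules Loading.is_load_module() accepts by loading their real weights from the checkpoint. A rough usage sketch of that path for a LLAMA2 model; the model name, tp_size, and checkpoint JSON path are assumptions for illustration, not part of this commit:

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder model name; any LLAMA2 checkpoint layout with LlamaRMSNorm layers applies.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Build the model on the meta device: parameters are placeholders with no real storage yet.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# init_inference walks the modules; LlamaRMSNorm layers now pass Loading.is_load_module()
# and get their weights loaded from the checkpoint instead of staying on meta.
# "checkpoints.json" is an assumed DeepSpeed checkpoint description file.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 2},
                                  dtype=torch.float16,
                                  checkpoint="checkpoints.json",
                                  replace_with_kernel_inject=False)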
