
Commit

fix MegatronLayerPolicy to be compatible with the newest ParallelTransformerLayer (#4236)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
dc3671 and RezaYazdaniAminabadi authored Aug 30, 2023
1 parent 5dbc531 commit 6cbf666
1 changed file: deepspeed/module_inject/containers/megatron_gpt.py (11 additions, 4 deletions)
@@ -51,14 +51,21 @@ def __init__(self, client_module, inference=True):
             try:
                 from megatron.model.transformer import ParallelTransformerLayer
                 MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
+                MegatronLayerPolicy.version = 1
             except ImportError:
                 MegatronLayerPolicy._orig_layer_class = None
 
     def get_hidden_heads(self):
-        return self.client_module.attention.query_key_value.weight.shape[1], \
-                self.client_module.attention.num_attention_heads, \
-                self.client_module.input_layernorm.eps, \
-                DEFAULT_INTERMEDIATE_SIZE
+        if MegatronLayerPolicy.version == 0:
+            return self.client_module.attention.query_key_value.weight.shape[1], \
+                    self.client_module.attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
+        else:
+            return self.client_module.self_attention.query_key_value.weight.shape[1], \
+                    self.client_module.self_attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
 
     def attention(self, enable_training=False):
         if self.inference:
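For readers unfamiliar with the Megatron naming change this commit targets, the sketch below illustrates the behaviour the diff adds. It is a hypothetical stand-alone example: LegacyLayer, NewLayer, and pick_attention are made-up stubs, not Megatron or DeepSpeed APIs; they only show how the version flag selects between the older attention attribute and the newer self_attention attribute.

# Minimal, hypothetical sketch (not Megatron or DeepSpeed code) of the naming
# difference the diff handles: older ParallelTransformerLayer exposes its
# attention block as `attention`, newer versions as `self_attention`, and the
# policy's version flag picks which attribute to read.

class LegacyLayer:  # stand-in for the older layer (version 0 naming)
    def __init__(self):
        self.attention = "qkv weights and head count live here"

class NewLayer:  # stand-in for the newer layer (version 1 naming)
    def __init__(self):
        self.self_attention = "qkv weights and head count live here"

def pick_attention(client_module, version):
    """Mirror the diff's branch: choose the attribute name by policy version."""
    return client_module.attention if version == 0 else client_module.self_attention

print(pick_attention(LegacyLayer(), version=0))  # reads .attention
print(pick_attention(NewLayer(), version=1))     # reads .self_attention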
