[LLM] Reconstruct fused transformer layers (#7186)
* reconstruct fused_transformer_layers

* delete original class

* refine code
RichardWooSJTU authored Oct 12, 2023
1 parent d0c85df commit 64f00f0
Showing 7 changed files with 545 additions and 349 deletions.
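The two diffs loaded below follow the same pattern: the long argument list that previously went straight into the FusedMultiTransformer constructor is first collected into a FusedMultiTransformerConfig object, which is then passed to the new FusedMultiTransformerBase class. A minimal sketch of the before/after shape, with the argument list abbreviated (the full parameter set is not visible in the loaded hunks):

# Before this commit: one monolithic constructor call.
self.transformer_block = FusedMultiTransformer(
    self.embed_dim,
    self.n_head,
    4 * self.embed_dim,
    ffn2_weight_attrs=ffn2_weight_attrs,
    ffn2_bias_attrs=ffn2_bias_attrs,
    # ... further per-layer parameters elided ...
)

# After this commit: the same arguments are packed into a config object,
# and the fused block is constructed from that config.
transformer_config = FusedMultiTransformerConfig(
    self.embed_dim,
    self.n_head,
    4 * self.embed_dim,
    ffn2_weight_attrs=ffn2_weight_attrs,
    ffn2_bias_attrs=ffn2_bias_attrs,
    # ... further per-layer parameters elided ...
)
self.transformer_block = FusedMultiTransformerBase(transformer_config)

Splitting construction this way lets each model (BLOOM, ChatGLM, and the other changed files) describe its layer layout declaratively, while the shared wiring can live in the base class.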
paddlenlp/experimental/transformers/bloom/modeling.py (6 additions, 2 deletions)
@@ -21,7 +21,8 @@
 from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
-    FusedMultiTransformer,
+    FusedMultiTransformerBase,
+    FusedMultiTransformerConfig,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
@@ -112,7 +113,8 @@ def __init__(self, config):
         ffn1_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(config.n_layer)]
         ffn2_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(config.n_layer)]
         ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.n_layer)]
-        self.transformer_block = FusedMultiTransformer(
+
+        transformer_config = FusedMultiTransformerConfig(
             self.embed_dim,
             self.n_head,
             4 * self.embed_dim,
@@ -133,6 +135,8 @@ def __init__(self, config):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
         )
+
+        self.transformer_block = FusedMultiTransformerBase(transformer_config)
         self.cache_kvs = []
 
         # Final Layer Norm
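The context lines above also show how the per-layer parameter attributes fed into the config are prepared: one paddle.ParamAttr per layer, named after the layer index, collected into plain lists. A small illustrative sketch of that pattern (the make_layer_attrs helper below is hypothetical, not part of the commit):

import paddle

# Hypothetical helper: build one named ParamAttr per fused layer so every
# layer's weights remain individually addressable in the state dict.
def make_layer_attrs(prefix, suffix, num_layers):
    return [
        paddle.ParamAttr(name="{}.{}.{}".format(prefix, i, suffix))
        for i in range(num_layers)
    ]

# Mirrors the lists built in the diff above; n_layer comes from the model config.
n_layer = 2  # example value for illustration
ffn2_weight_attrs = make_layer_attrs("fusemt", "ffn2_weight", n_layer)
ffn2_bias_attrs = make_layer_attrs("fusemt", "ffn2_bias", n_layer)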
paddlenlp/experimental/transformers/chatglm/modeling.py (5 additions, 2 deletions)
@@ -20,7 +20,8 @@
 from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
-    FusedMultiTransformer,
+    FusedMultiTransformerBase,
+    FusedMultiTransformerConfig,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
     GenerationInferenceModel,
@@ -183,7 +184,8 @@ def __init__(self, config: ChatGLMConfig):
         ]
         ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.num_layers)]
         alpha = (2 * self.config.num_hidden_layers) ** 0.5
-        self.transformer_block = FusedMultiTransformer(
+
+        transformer_config = FusedMultiTransformerConfig(
             config.hidden_size,
             config.num_attention_heads,
             4 * config.hidden_size,
@@ -209,6 +211,7 @@ def __init__(self, config: ChatGLMConfig):
             norm_type="layernorm",
             use_neox_rotary_style=True,
         )
+        self.transformer_block = FusedMultiTransformerBase(transformer_config)
 
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
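The ChatGLM hunk shows two model-specific knobs the config carries in addition to the shared ones: norm_type="layernorm" and use_neox_rotary_style=True. The diff for fused_transformer_layers.py itself is not loaded on this page, so the following dataclass is only a hypothetical sketch of what the config container could look like, with the field set inferred from the arguments visible in the two diffs above:

from dataclasses import dataclass
from typing import List, Optional

import paddle

# Hypothetical sketch only: the real FusedMultiTransformerConfig is defined in
# paddlenlp/experimental/transformers/fused_transformer_layers.py, whose diff
# is not shown here. Field names beyond those visible above are guesses.
@dataclass
class FusedMultiTransformerConfigSketch:
    embed_dim: int
    num_heads: int
    dim_feedforward: int
    ffn1_weight_attrs: Optional[List[paddle.ParamAttr]] = None
    ffn1_bias_attrs: Optional[List[paddle.ParamAttr]] = None
    ffn2_weight_attrs: Optional[List[paddle.ParamAttr]] = None
    ffn2_bias_attrs: Optional[List[paddle.ParamAttr]] = None
    norm_type: str = "layernorm"
    use_neox_rotary_style: bool = False

A container like this keeps the model classes' constructor calls short and makes new layer options additive: a model such as ChatGLM can flip use_neox_rotary_style without every caller's signature changing.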
(Diffs for the remaining 5 changed files did not load.)
