Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions paddleformers/transformers/gpt_oss/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import math
from functools import partial
from typing import Optional, Tuple, Union

import paddle
Expand Down Expand Up @@ -633,6 +634,53 @@ class GptOssPreTrainedModel(PretrainedModel):
keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

@classmethod
def _get_tensor_parallel_mappings(cls, config: GptOssConfig, is_split=True):
    """Build the tensor-parallel split/merge action table for this model.

    Maps every shardable parameter name to a callable that splits (when
    ``is_split`` is True) or merges it across ``config.tensor_parallel_degree``
    ranks, using row-parallel or column-parallel partitioning as appropriate.

    Args:
        config: Model configuration providing TP degree/rank, head counts,
            vocab size and layer count.
        is_split: True to produce split actions (load time), False to produce
            merge actions (save time).

    Returns:
        dict[str, Callable]: parameter name -> split/merge action.
    """
    from ..conversion_utils import split_or_merge_func

    fn = split_or_merge_func(
        is_split=is_split,
        tensor_parallel_degree=config.tensor_parallel_degree,
        tensor_parallel_rank=config.tensor_parallel_rank,
        num_attention_heads=config.num_attention_heads,
    )

    def get_tensor_parallel_split_mappings(num_layers, num_experts):
        # NOTE(review): num_experts is currently unused; presumably kept for
        # future expert-parallel mappings — confirm before removing.
        final_actions = {}

        # Row-parallel parameters (split along the input/vocab dimension).
        base_actions = {
            "lm_head.weight": partial(fn, is_column=False),
            "embed_tokens.weight": partial(fn, is_column=False),
            "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
        }

        # The vocab-dimension weights can only be sharded when the vocab size
        # divides evenly across TP ranks; otherwise keep them replicated.
        if config.vocab_size % config.tensor_parallel_degree != 0:
            base_actions.pop("lm_head.weight")
            base_actions.pop("embed_tokens.weight")
        base_actions["layers.0.self_attn.sinks"] = partial(fn, is_column=False)
        # Column-parallel parameters (split along the output dimension).
        base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
        base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)

        # Only shard K/V projections when the key/value heads divide evenly
        # across TP ranks; otherwise they stay replicated.
        if config.num_key_value_heads % config.tensor_parallel_degree == 0:
            base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

        # Expand the "layers.0." template keys to every layer index; non-layer
        # keys (lm_head, embed_tokens) are registered as-is.
        for key, action in base_actions.items():
            if "layers.0." in key:
                for i in range(num_layers):
                    final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
            final_actions[key] = action

        return final_actions

    return get_tensor_parallel_split_mappings(config.num_hidden_layers, config.num_experts)

@classmethod
def _gen_aoa_config(cls, config: GptOssConfig):
model_prefix = "" if cls == cls.base_model_class else "model."
Expand Down