Skip to content

Commit fa2906f

Browse files
authored
update RotaryEmbedding for multiple models (#3076)
1 parent 36d37a0 commit fa2906f

File tree

11 files changed

+22
-0
lines changed

11 files changed

+22
-0
lines changed

paddleformers/transformers/ernie4_5/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def apply_fused_rope(query_states, key_states, rope_theta):
114114
class Ernie4_5RotaryEmbedding(nn.Layer):
115115
def __init__(self, config):
116116
super().__init__()
117+
self.max_seq_len_cached = config.max_position_embeddings
118+
self.original_max_seq_len = config.max_position_embeddings
117119
self.config = config
118120
self.head_dim = config.head_dim
119121
self.base = config.rope_theta

paddleformers/transformers/ernie4_5_moe/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def mtp_hidden_states_set_zero(hidden_states, inbatch_pack_offset):
9292
class Ernie4_5_MoeRotaryEmbedding(nn.Layer):
9393
def __init__(self, config):
9494
super().__init__()
95+
self.max_seq_len_cached = config.max_position_embeddings
96+
self.original_max_seq_len = config.max_position_embeddings
9597
self.config = config
9698
self.head_dim = config.head_dim
9799
self.base = config.rope_theta

paddleformers/transformers/gemma3_text/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def forward(self, x):
123123
class Gemma3RotaryEmbedding(nn.Layer):
124124
def __init__(self, config):
125125
super().__init__()
126+
self.max_seq_len_cached = config.max_position_embeddings
127+
self.original_max_seq_len = config.max_position_embeddings
126128
self.config = config
127129
base = config.rope_theta
128130
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,8 @@ def _gen_inv_aoa_config(cls, config: Glm4MoeConfig):
12011201
class Glm4MoeRotaryEmbedding(nn.Layer):
12021202
def __init__(self, config: Glm4MoeConfig, device=None):
12031203
super().__init__()
1204+
self.max_seq_len_cached = config.max_position_embeddings
1205+
self.original_max_seq_len = config.max_position_embeddings
12041206
self.config = config
12051207
base = config.rope_theta
12061208
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/llama/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ def _compute_llama3_parameters(config):
289289
class LlamaRotaryEmbedding(nn.Layer):
290290
def __init__(self, config):
291291
super().__init__()
292+
self.max_seq_len_cached = config.max_position_embeddings
293+
self.original_max_seq_len = config.max_position_embeddings
292294
self.config = config
293295
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
294296

paddleformers/transformers/phi3/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ def forward(
255255
class Phi3RotaryEmbedding(nn.Layer):
256256
def __init__(self, config: Phi3Config, device=None):
257257
super().__init__()
258+
self.max_seq_len_cached = config.max_position_embeddings
259+
self.original_max_seq_len = config.max_position_embeddings
258260
self.config = config
259261
base = config.rope_theta
260262
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/qwen2/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ def _gen_inv_aoa_config(cls, config: Qwen2Config):
457457
class Qwen2RotaryEmbedding(nn.Layer):
458458
def __init__(self, config: Qwen2Config):
459459
super().__init__()
460+
self.max_seq_len_cached = config.max_position_embeddings
461+
self.original_max_seq_len = config.max_position_embeddings
460462
self.config = config
461463
base = config.rope_theta
462464
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/qwen2_5_vl/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,8 @@ class Qwen2_5_VLRotaryEmbedding(nn.Layer):
705705

706706
def __init__(self, config: Qwen2_5_VLTextConfig):
707707
super().__init__()
708+
self.max_seq_len_cached = config.max_position_embeddings
709+
self.original_max_seq_len = config.max_position_embeddings
708710
self.config = config
709711
base = config.rope_theta
710712
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/qwen2_moe/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,8 @@ def forward(
436436
class Qwen2MoeRotaryEmbedding(nn.Layer):
437437
def __init__(self, config: Qwen2MoeConfig):
438438
super().__init__()
439+
self.max_seq_len_cached = config.max_position_embeddings
440+
self.original_max_seq_len = config.max_position_embeddings
439441
self.config = config
440442
base = config.rope_theta
441443
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

paddleformers/transformers/qwen3/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,8 @@ def _gen_inv_aoa_config(cls, config: Qwen3Config):
479479
class Qwen3RotaryEmbedding(nn.Layer):
480480
def __init__(self, config: Qwen3Config):
481481
super().__init__()
482+
self.max_seq_len_cached = config.max_position_embeddings
483+
self.original_max_seq_len = config.max_position_embeddings
482484
self.config = config
483485
base = config.rope_theta
484486
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

0 commit comments

Comments
 (0)