@@ -112,6 +112,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class FalconRotaryEmbedding(nn.Module):
     def __init__(self, config: FalconConfig, device=None):
         super().__init__()
+        self.rope_kwargs = {}
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -123,7 +124,7 @@ def __init__(self, config: FalconConfig, device=None):
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

@@ -135,7 +136,9 @@ def _dynamic_frequency_update(self, position_ids, device):
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len

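
For context, `ROPE_INIT_FUNCTIONS` maps each `rope_type` to an initializer that returns `(inv_freq, attention_scaling)`; with this change it is also expected to accept a `seq_len` keyword plus any extra `rope_kwargs`. Below is a minimal sketch of such an initializer for the default (unscaled) RoPE case. `DummyConfig` and `default_rope_init` are hypothetical names used for illustration, not the transformers API:

```python
import torch


class DummyConfig:
    """Hypothetical stand-in for FalconConfig; only the fields this sketch reads."""

    def __init__(self, head_dim=64, rope_theta=10000.0):
        self.head_dim = head_dim
        self.rope_theta = rope_theta


def default_rope_init(config, device=None, seq_len=None, **rope_kwargs):
    # Standard RoPE inverse frequencies: theta^(-2i/d) for i in [0, d/2).
    # seq_len and rope_kwargs are accepted (matching the call sites in the
    # diff above) but unused here; scaled variants would consume them.
    dim = config.head_dim
    inv_freq = 1.0 / (
        config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
    )
    return inv_freq, 1.0  # attention_scaling is 1.0 for unscaled RoPE


inv_freq, attention_scaling = default_rope_init(DummyConfig())
print(inv_freq.shape, attention_scaling)  # torch.Size([32]) 1.0
```

Keeping `seq_len` and `**rope_kwargs` in the shared signature is what lets `_dynamic_frequency_update` recompute `inv_freq` at runtime through the same `rope_init_fn` entry point.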