@@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
3737        super ().__init__ ()
3838        self .dim  =  config .dim 
3939        self .n_heads  =  config .n_heads 
40-         self .head_dim  =  config .dim   //   config . n_heads 
40+         self .head_dim  =  config .head_dim 
4141        self .n_kv_heads  =  config .n_kv_heads 
4242        self .num_key_value_groups  =  config .n_heads  //  self .n_kv_heads 
4343        self .max_seq_len  =  config .max_seq_len 
@@ -304,7 +304,7 @@ def __init__(
304304    ):
305305        super ().__init__ ()
306306        self .dim  =  config .dim 
307-         self .head_dim  =  config .dim   //   config . n_heads 
307+         self .head_dim  =  config .head_dim 
308308        self .max_batch_size  =  config .max_batch_size 
309309        self .max_seq_len  =  config .max_seq_len 
310310        self .n_heads  =  config .n_heads 
@@ -328,9 +328,11 @@ def __init__(
328328        self .output  =  nn .Linear (config .dim , config .vocab_size , bias = False )
329329        self .tok_embeddings  =  nn .Embedding (config .vocab_size , config .dim )
330330        freqs_cos , freqs_sin  =  precompute_freqs_cis (
331-             config .dim   //   config . n_heads ,
331+             config .head_dim ,
332332            config .max_seq_len ,
333333            config .rope_freq_base ,
334+             config .use_scaled_rope ,
335+             config .rope_scale_factor ,
334336        )
335337        self .register_buffer ("freqs_cos" , freqs_cos , persistent = False )
336338        self .register_buffer ("freqs_sin" , freqs_sin , persistent = False )
0 commit comments