File tree Expand file tree Collapse file tree 1 file changed +18
-3
lines changed Expand file tree Collapse file tree 1 file changed +18
-3
lines changed Original file line number Diff line number Diff line change @@ -686,9 +686,24 @@ def recurse_elems(elem: Any):
686686 config_dict ["hidden_act" ] = config_dict .get ("activation" , "silu" )
687687 config_dict ["tie_word_embeddings" ] = config_dict .get (
688688 "tie_embeddings" , False )
689- config_dict ["max_seq_len" ] = config_dict .get ("max_seq_len" , 128_000 )
690- config_dict ["max_position_embeddings" ] = config_dict .get (
691- "max_position_embeddings" , 128_000 )
689+
690+ if config_dict .get ("max_position_embeddings" ) is None :
691+ max_position_embeddings = 128_000
692+ try :
693+ trust_remote_code_val = kwargs .get ("trust_remote_code" , False )
694+ hf_config = get_config (model = model ,
695+ trust_remote_code = trust_remote_code_val ,
696+ revision = revision ,
697+ config_format = ConfigFormat .HF )
698+ if hf_value := hf_config .get_text_config ().max_position_embeddings :
699+ max_position_embeddings = hf_value
700+ except Exception as e :
701+ logger .warning (
702+ "The params.json file is missing 'max_position_embeddings'"
703+ " and could not get a value from the HF config."
704+ " Defaulting to 128000" ,
705+ exc_info = e )
706+ config_dict ["max_position_embeddings" ] = max_position_embeddings
692707
693708 if config_dict .get ("quantization" ) is not None :
694709 quantization = config_dict .get ("quantization" , {})
You can’t perform that action at this time.
0 commit comments