Skip to content

Commit 4730f04

Browse files
committed
Change the convert tools.
1 parent 4bfe3d9 commit 4730f04

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,11 +49,11 @@ xFasterTransformer provides a series of APIs, both of C++ and Python, for end us
4949
| ChatGLM3 | ✔ | ✔ | ✔ |
5050
| Llama | ✔ | ✔ | ✔ |
5151
| Llama2 | ✔ | ✔ | ✔ |
52-
| Deepseek-coder | ✔ | ✔ | ✔ |
5352
| Baichuan | ✔ | ✔ | ✔ |
5453
| QWen | ✔ | ✔ | ✔ |
5554
| SecLLM(YaRN-Llama) | ✔ | ✔ | ✔ |
5655
| Opt | ✔ | ✔ | ✔ |
56+
| Deepseek-coder | ✔ | ✔ | ✔ |
5757

5858
### DataType support list
5959

src/xfastertransformer/tools/llama_convert.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -125,12 +125,15 @@ def split_and_convert(self, input_dir, output_dir, dtype, processes):
125125
config["llama"]["layernorm_type"] = "pre_layernorm"
126126
config["llama"]["activation_type"] = str(hf_config["hidden_act"])
127127
config["llama"]["rope_theta"] = str(hf_config.get("rope_theta", 10000))
128-
try:
129-
config["llama"]["scaling_factor"] = str(hf_config["rope_scaling"]["factor"])
130-
config["llama"]["rope_type"] = str(hf_config["rope_scaling"]["type"])
131-
except Exception as e:
132-
config["llama"]["scaling_factor"] = 1.0
133-
config["llama"]["rope_type"] = "null"
128+
129+
rope_scaling = hf_config.get("rope_scaling", None)
130+
if rope_scaling:
131+
config["llama"]["scaling_factor"] = str(rope_scaling.get("factor", 1.0))
132+
config["llama"]["rope_type"] = str(rope_scaling.get("type", "null"))
133+
else:
134+
config["llama"]["scaling_factor"] = str(1.0)
135+
config["llama"]["rope_type"] = str("null")
136+
134137
config["llama"]["has_post_decoder_layernorm"] = "1" if has_post_decoder_layernorm else "0"
135138
config["llama"]["vocab_size"] = str(hf_config["vocab_size"])
136139
config["llama"]["start_id"] = str(hf_config["bos_token_id"])
@@ -140,7 +143,6 @@ def split_and_convert(self, input_dir, output_dir, dtype, processes):
140143
config.write(configfile)
141144
except Exception as e:
142145
print("Fail to save the config in config.ini.", str(e))
143-
144146
hf_model_name_pattern = [
145147
"input_layernorm.weight",
146148
"attention.query_key_value.weight",

0 commit comments

Comments (0)