@@ -10,8 +10,8 @@
 
 # Architecture -> (module, class).
 _MODELS = {
-    "AquilaModel": ("llama", "LlamaForCausalLM"),
-    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
@@ -24,12 +24,12 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
@@ -41,6 +41,7 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM")
 }
 
 # Models not supported by ROCm.
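For context on what this table does: each key is an architecture name as it appears in a checkpoint's Hugging Face `config.json` (its `architectures` field), and the value names the module and class that implement it. Below is a minimal sketch of how such a registry is typically consumed, assuming this is vLLM's model registry living under `vllm.model_executor.models`; `load_model_cls` is an illustrative helper name, not necessarily the project's actual API.

```python
import importlib
from typing import Optional, Type

# Hypothetical resolver for a registry shaped like _MODELS above.
def load_model_cls(arch: str) -> Optional[Type]:
    """Map an architecture name from config.json's "architectures"
    field to its implementing class, importing the module lazily."""
    if arch not in _MODELS:
        return None
    module_name, cls_name = _MODELS[arch]
    # e.g. "MistralForCausalLM" -> ("mistral", "MistralForCausalLM")
    #      -> vllm.model_executor.models.mistral.MistralForCausalLM
    # (package path is an assumption, not confirmed by this diff)
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name, None)
```

Under these assumptions, `load_model_cls("MistralForCausalLM")` would import `mistral.py` on first use and return the class, while an unknown architecture falls through to `None` so the caller can raise a clear error. The lazy import keeps startup cheap: model files for architectures that are never requested are never loaded.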