Add Qwen2MoE #29377

Merged · 46 commits · Mar 27, 2024

Changes from 1 commit

Commits (46):
4f933bb  add support for qwen2 MoE models (Feb 28, 2024)
8ad6c9e  update docs (Feb 28, 2024)
fbce3b9  add support for qwen2 MoE models (Feb 28, 2024)
c32b998  update docs (Feb 28, 2024)
8274f89  Merge branch 'qwen2_moe' of https://github.com/bozheng-hit/transforme… (Feb 28, 2024)
e44f700  update model name & test (Feb 29, 2024)
b09e2ed  update readme (Feb 29, 2024)
d5e99a6  update class names & readme & model_doc of Qwen2MoE. (Feb 29, 2024)
1625b1f  update architecture name (Feb 29, 2024)
051e19d  fix qwen2_moe tests (Feb 29, 2024)
307d9de  use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
4d80bf8  update modeling_qwen2_moe.py (Mar 1, 2024)
8b6d57b  fix model architecture (Mar 9, 2024)
b9c2803  fix qwen2_moe tests (Feb 29, 2024)
f8e1819  use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
e4b8445  update modeling_qwen2_moe.py (Mar 1, 2024)
8d74bb0  fix model architecture (Mar 9, 2024)
a50a208  fix style (Mar 10, 2024)
a04c698  fix test when there are sparse and non sparse layers (Mar 10, 2024)
dc53a8d  fixup (Mar 21, 2024)
8f55aa5  Update README.md (bozheng-hit, Mar 21, 2024)
6a06f8e  fix up (Mar 21, 2024)
bf11227  fixup (Mar 22, 2024)
e3038db  fixup (Mar 23, 2024)
5c627d3  add archive back (Mar 23, 2024)
765ebf5  add support for qwen2 MoE models (Feb 28, 2024)
1c973fb  update docs (Feb 28, 2024)
0841722  update model name & test (Feb 29, 2024)
4c0b2b1  update readme (Feb 29, 2024)
8958743  update class names & readme & model_doc of Qwen2MoE. (Feb 29, 2024)
1e099c5  update architecture name (Feb 29, 2024)
4906cdf  fix qwen2_moe tests (Feb 29, 2024)
82729ec  use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
a3aa52d  update modeling_qwen2_moe.py (Mar 1, 2024)
0686cc6  fix model architecture (Mar 9, 2024)
c074021  fixup (Mar 21, 2024)
2484604  fix qwen2_moe tests (Feb 29, 2024)
5d1ed37  use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
27afcd5  fix style (Mar 10, 2024)
0d155e9  fix test when there are sparse and non sparse layers (Mar 10, 2024)
46b0918  fixup (Mar 23, 2024)
45219a1  add archive back (Mar 23, 2024)
cf61e7e  fixup (Mar 25, 2024)
3b9f3a8  fix integration test (Mar 26, 2024)
4077877  fixup (Mar 26, 2024)
4d931f0  Merge branch 'main' into qwen2_moe (bozheng-hit, Mar 26, 2024)

Viewing commit c074021466255b7393ef938f2c3189d9b2795dce ("fixup"), committed by bozheng-hit on Mar 25, 2024

README.md (2 changes: 1 addition & 1 deletion)
@@ -473,7 +473,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[PVTv2](https://huggingface.co/docs/transformers/model_doc/pvt_v2)** (from Shanghai AI Laboratory, Nanjing University, The University of Hong Kong etc.) released with the paper [PVT v2: Improved Baselines with Pyramid Vision Transformer](https://arxiv.org/abs/2106.13797) by Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao.
1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius.
1. **[Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2)** (from the Qwen team, Alibaba Group) released with the paper [Qwen Technical Report](https://arxiv.org/abs/2309.16609) by Jinze Bai, Shuai Bai, Yunfei Chu, Zeyu Cui, Kai Dang, Xiaodong Deng, Yang Fan, Wenbin Ge, Yu Han, Fei Huang, Binyuan Hui, Luo Ji, Mei Li, Junyang Lin, Runji Lin, Dayiheng Liu, Gao Liu, Chengqiang Lu, Keming Lu, Jianxin Ma, Rui Men, Xingzhang Ren, Xuancheng Ren, Chuanqi Tan, Sinan Tan, Jianhong Tu, Peng Wang, Shijie Wang, Wei Wang, Shengguang Wu, Benfeng Xu, Jin Xu, An Yang, Hao Yang, Jian Yang, Shusheng Yang, Yang Yao, Bowen Yu, Hongyi Yuan, Zheng Yuan, Jianwei Zhang, Xingxuan Zhang, Yichang Zhang, Zhenru Zhang, Chang Zhou, Jingren Zhou, Xiaohuan Zhou and Tianhang Zhu.
-1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with the paper [blog post](https://qwenlm.github.io/blog/qwen1.5/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
+1. **[Qwen2MoE](https://huggingface.co/docs/transformers/main/model_doc/qwen2_moe)** (from the Qwen team, Alibaba Group) released with [blog post](https://qwenlm.github.io/blog/qwen1.5/) by Bo Zheng, Dayiheng Liu, Rui Men, Junyang Lin, Zhou San, Bowen Yu, An Yang, Mingfeng Xue, Fei Huang, Binyuan Hui, Mei Li, Tianyu Liu, Xingzhang Ren, Xuancheng Ren, Kexin Yang, Chang Zhou, Jingren Zhou.
1. **[RAG](https://huggingface.co/docs/transformers/model_doc/rag)** (from Facebook) released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.

src/transformers/models/qwen2_moe/configuration_qwen2_moe.py (6 changes: 3 additions & 3 deletions)
@@ -78,7 +78,7 @@ class Qwen2MoeConfig(PretrainedConfig):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
-        expert_interval (`int`, *optional*, defaults to 1):
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert.
@@ -131,7 +131,7 @@ def __init__(
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
-        expert_interval=1,
+        decoder_sparse_step=1,
        moe_intermediate_size=1408,
        shared_expert_intermediate_size=5632,
        num_experts_per_tok=4,
@@ -164,7 +164,7 @@ def __init__(
        self.attention_dropout = attention_dropout

        # MoE arguments
-        self.expert_interval = expert_interval
+        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
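
For readers skimming the rename: below is a minimal sketch (not part of the PR) of how the new `decoder_sparse_step` value selects which decoder layers get the sparse MoE block, reusing the `(layer_idx + 1) % decoder_sparse_step == 0` rule from the modeling change further down. The helper name and the example numbers are invented for illustration.

```python
# Hypothetical helper, not from the PR: list the 0-based layer indices that would
# use Qwen2MoeSparseMoeBlock, following the rule in modeling_qwen2_moe.py.
def sparse_layer_indices(num_hidden_layers: int, num_experts: int, decoder_sparse_step: int) -> list:
    if num_experts <= 0:
        return []  # no experts configured, every layer stays dense
    return [i for i in range(num_hidden_layers) if (i + 1) % decoder_sparse_step == 0]

# decoder_sparse_step=1 (the default) makes every layer sparse;
# decoder_sparse_step=2 makes every second layer sparse.
print(sparse_layer_indices(num_hidden_layers=6, num_experts=8, decoder_sparse_step=1))  # [0, 1, 2, 3, 4, 5]
print(sparse_layer_indices(num_hidden_layers=6, num_experts=8, decoder_sparse_step=2))  # [1, 3, 5]
```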

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (13 changes: 4 additions & 9 deletions)
@@ -871,14 +871,9 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

-        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
-            logger.warning_once(
-                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
-                "unexpected results may be encountered."
-            )
        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

-        if config.num_experts > 0 and (layer_idx + 1) % config.expert_interval == 0:
+        if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0:
            self.mlp = Qwen2MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
@@ -938,10 +933,10 @@ def forward(
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

-        if isinstance(self.mlp, Qwen2MoeSparseMoeBlock):
-            hidden_states, router_logits = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if isinstance(hidden_states, tuple):
+            hidden_states, router_logits = hidden_states
        else:
-            hidden_states = self.mlp(hidden_states)
            router_logits = None

        hidden_states = residual + hidden_states
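
To make the new control flow above easier to follow: the sparse block returns a `(hidden_states, router_logits)` tuple while the dense MLP returns a plain tensor, so checking the output type replaces the earlier `isinstance(self.mlp, Qwen2MoeSparseMoeBlock)` check. A minimal, self-contained sketch with stand-in modules (not the real Qwen2Moe classes):

```python
import torch
import torch.nn as nn

class DenseMLP(nn.Module):
    # Stand-in for Qwen2MoeMLP: returns a plain tensor.
    def forward(self, hidden_states):
        return hidden_states * 2.0

class SparseMoeStandIn(nn.Module):
    # Stand-in for Qwen2MoeSparseMoeBlock: returns (hidden_states, router_logits).
    def forward(self, hidden_states):
        router_logits = torch.zeros(hidden_states.shape[0], 4)  # fake router scores
        return hidden_states * 2.0, router_logits

def run_mlp(mlp, hidden_states):
    # Mirrors the updated decoder-layer logic: unpack router logits only when present.
    hidden_states = mlp(hidden_states)
    if isinstance(hidden_states, tuple):
        hidden_states, router_logits = hidden_states
    else:
        router_logits = None
    return hidden_states, router_logits

x = torch.ones(2, 8)
print(run_mlp(DenseMLP(), x)[1])                # None
print(run_mlp(SparseMoeStandIn(), x)[1].shape)  # torch.Size([2, 4])
```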