
Add Qwen2MoE #29377

Merged
merged 46 commits on Mar 27, 2024
Changes from 1 commit
Commits (46)
4f933bb add support for qwen2 MoE models (Feb 28, 2024)
8ad6c9e update docs (Feb 28, 2024)
fbce3b9 add support for qwen2 MoE models (Feb 28, 2024)
c32b998 update docs (Feb 28, 2024)
8274f89 Merge branch 'qwen2_moe' of https://github.com/bozheng-hit/transforme… (Feb 28, 2024)
e44f700 update model name & test (Feb 29, 2024)
b09e2ed update readme (Feb 29, 2024)
d5e99a6 update class names & readme & model_doc of Qwen2MoE. (Feb 29, 2024)
1625b1f update architecture name (Feb 29, 2024)
051e19d fix qwen2_moe tests (Feb 29, 2024)
307d9de use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
4d80bf8 update modeling_qwen2_moe.py (Mar 1, 2024)
8b6d57b fix model architecture (Mar 9, 2024)
b9c2803 fix qwen2_moe tests (Feb 29, 2024)
f8e1819 use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
e4b8445 update modeling_qwen2_moe.py (Mar 1, 2024)
8d74bb0 fix model architecture (Mar 9, 2024)
a50a208 fix style (Mar 10, 2024)
a04c698 fix test when there are sparse and non sparse layers (Mar 10, 2024)
dc53a8d fixup (Mar 21, 2024)
8f55aa5 Update README.md (bozheng-hit, Mar 21, 2024)
6a06f8e fix up (Mar 21, 2024)
bf11227 fixup (Mar 22, 2024)
e3038db fixup (Mar 23, 2024)
5c627d3 add archive back (Mar 23, 2024)
765ebf5 add support for qwen2 MoE models (Feb 28, 2024)
1c973fb update docs (Feb 28, 2024)
0841722 update model name & test (Feb 29, 2024)
4c0b2b1 update readme (Feb 29, 2024)
8958743 update class names & readme & model_doc of Qwen2MoE. (Feb 29, 2024)
1e099c5 update architecture name (Feb 29, 2024)
4906cdf fix qwen2_moe tests (Feb 29, 2024)
82729ec use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
a3aa52d update modeling_qwen2_moe.py (Mar 1, 2024)
0686cc6 fix model architecture (Mar 9, 2024)
c074021 fixup (Mar 21, 2024)
2484604 fix qwen2_moe tests (Feb 29, 2024)
5d1ed37 use Qwen2Tokenizer instead of Qwen2MoeTokenizer (Mar 1, 2024)
27afcd5 fix style (Mar 10, 2024)
0d155e9 fix test when there are sparse and non sparse layers (Mar 10, 2024)
46b0918 fixup (Mar 23, 2024)
45219a1 add archive back (Mar 23, 2024)
cf61e7e fixup (Mar 25, 2024)
3b9f3a8 fix integration test (Mar 26, 2024)
4077877 fixup (Mar 26, 2024)
4d931f0 Merge branch 'main' into qwen2_moe (bozheng-hit, Mar 26, 2024)
add support for qwen2 MoE models
bozheng-hit committed Mar 25, 2024
commit 765ebf5314fc9a67df025a85892d4f2910498832
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -7649,6 +7649,12 @@
    Qwen2Model,
    Qwen2PreTrainedModel,
)
from .models.qwen2_moe import (
    Qwen2MoEForCausalLM,
    Qwen2MoEForSequenceClassification,
    Qwen2MoEModel,
    Qwen2MoEPreTrainedModel,
)
from .models.rag import (
    RagModel,
    RagPreTrainedModel,
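With these exports in place, the new classes become importable directly from `transformers`. A minimal sketch, assuming an editable install of this branch with the PR's full set of `__init__` changes (this commit shows only the type-checking imports), and using the class names as they appear at this commit:

```python
# Sketch: import the classes exported above. Assumes this PR branch is
# installed; class names reflect this commit and later commits in the PR
# revise some of the naming.
from transformers import (
    Qwen2MoEForCausalLM,
    Qwen2MoEForSequenceClassification,
    Qwen2MoEModel,
    Qwen2MoEPreTrainedModel,
)

print(Qwen2MoEModel.__module__)  # transformers.models.qwen2_moe.modeling_qwen2_moe
```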
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -191,6 +191,7 @@
("pvt_v2", "PvtV2Config"),
("qdqbert", "QDQBertConfig"),
("qwen2", "Qwen2Config"),
("qwen2_moe", "Qwen2MoEConfig"),
("rag", "RagConfig"),
("realm", "RealmConfig"),
("reformer", "ReformerConfig"),
@@ -428,6 +429,7 @@
("pvt_v2", "PVT_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qwen2", "QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qwen2_moe", "QWEN2MOE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -687,6 +689,7 @@
("pvt_v2", "PVTv2"),
("qdqbert", "QDQBert"),
("qwen2", "Qwen2"),
("qwen2_moe", "Qwen2MoE"),
("rag", "RAG"),
("realm", "REALM"),
("reformer", "Reformer"),
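These `configuration_auto` entries register the new `qwen2_moe` model type with the Auto API. A rough sketch of the effect, assuming this branch is installed; it uses `AutoConfig.for_model` so no hub checkpoint is needed:

```python
# Sketch: resolve the newly registered "qwen2_moe" model type through AutoConfig.
from transformers import AutoConfig

config = AutoConfig.for_model("qwen2_moe")  # builds a default Qwen2MoEConfig
print(type(config).__name__, config.model_type)  # Qwen2MoEConfig qwen2_moe
```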
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -182,6 +182,7 @@
("pvt_v2", "PvtV2Model"),
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_moe", "Qwen2MoEModel"),
("reformer", "ReformerModel"),
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
@@ -467,6 +468,7 @@
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("qwen2", "Qwen2ForCausalLM"),
("qwen2_moe", "Qwen2MoEForCausalLM"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForCausalLM"),
("roberta", "RobertaForCausalLM"),
@@ -871,6 +873,7 @@
("plbart", "PLBartForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("qwen2", "Qwen2ForSequenceClassification"),
("qwen2_moe", "Qwen2MoEForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
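With the model mappings above, the Auto classes can dispatch on the `qwen2_moe` model type. A minimal sketch built from a tiny config so nothing is downloaded; the sizes are illustrative only, not those of a released checkpoint:

```python
# Sketch: AutoModelForCausalLM dispatches to the Qwen2MoE class via the mapping above.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model(
    "qwen2_moe",
    vocab_size=1024,
    hidden_size=128,
    intermediate_size=256,
    moe_intermediate_size=64,
    shared_expert_intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_experts=4,
    num_experts_per_tok=2,
)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # Qwen2MoEForCausalLM
```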
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -354,6 +354,13 @@
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"qwen2_moe",
(
"Qwen2MoETokenizer",
"Qwen2MoETokenizerFast" if is_tokenizers_available() else None,
),
),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
(
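A quick way to see the effect of this hunk is to inspect the tokenizer mapping directly; note that later commits in this PR drop the dedicated tokenizer classes and reuse `Qwen2Tokenizer`, so this sketch reflects the state at this commit only:

```python
# Sketch: inspect the tokenizer mapping entry added above (this branch installed).
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES

print(TOKENIZER_MAPPING_NAMES["qwen2_moe"])
# ('Qwen2MoETokenizer', 'Qwen2MoETokenizerFast')  # fast entry is None without `tokenizers`
```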
80 changes: 80 additions & 0 deletions src/transformers/models/qwen2_moe/__init__.py
@@ -0,0 +1,80 @@
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tokenizers_available,
    is_torch_available,
)


_import_structure = {
    "configuration_qwen2_moe": ["QWEN2MOE_PRETRAINED_CONFIG_ARCHIVE_MAP", "Qwen2MoEConfig"],
    "tokenization_qwen2_moe": ["Qwen2MoETokenizer"],
}

try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_qwen2_moe_fast"] = ["Qwen2MoETokenizerFast"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_qwen2_moe"] = [
        "Qwen2MoEForCausalLM",
        "Qwen2MoEModel",
        "Qwen2MoEPreTrainedModel",
        "Qwen2MoEForSequenceClassification",
    ]


if TYPE_CHECKING:
    from .configuration_qwen2_moe import QWEN2MOE_PRETRAINED_CONFIG_ARCHIVE_MAP, Qwen2MoEConfig
    from .tokenization_qwen2_moe import Qwen2MoETokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_qwen2_moe_fast import Qwen2MoETokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_qwen2_moe import (
            Qwen2MoEForCausalLM,
            Qwen2MoEForSequenceClassification,
            Qwen2MoEModel,
            Qwen2MoEPreTrainedModel,
        )


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
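This file follows the library's standard lazy-import pattern: submodules are only imported when an attribute is first accessed. A small sketch of the runtime behavior, assuming this branch is installed and torch is available:

```python
# Sketch: the module object is replaced by a _LazyModule; attribute access
# triggers the real import of the submodule.
import transformers.models.qwen2_moe as qwen2_moe

print(type(qwen2_moe).__name__)         # _LazyModule
config_cls = qwen2_moe.Qwen2MoEConfig   # first access imports configuration_qwen2_moe
print(config_cls.model_type)            # qwen2_moe
```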
180 changes: 180 additions & 0 deletions src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Qwen2MoE model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging

logger = logging.get_logger(__name__)

QWEN2MOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"Qwen/Qwen2MoE-14B-beta": "https://huggingface.co/Qwen/Qwen2MoE-14B-beta/resolve/main/config.json",
}


class Qwen2MoEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2MoEModel`]. It is used to instantiate a
    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2MoE-14B-beta [Qwen/Qwen2MoE-14B-beta](https://huggingface.co/Qwen/Qwen2MoE-14B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen2MoEModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        expert_interval (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
            Intermediate size of the shared expert.
        shared_expert_gate (`bool`, *optional*, defaults to `True`):
            Whether to use the gating mechanism for the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 4):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 60):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the routing weights of the selected top-k experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.

    ```python
    >>> from transformers import Qwen2MoEModel, Qwen2MoEConfig

    >>> # Initializing a Qwen2MoE style configuration
    >>> configuration = Qwen2MoEConfig()

    >>> # Initializing a model from the Qwen2MoE-14B style configuration
    >>> model = Qwen2MoEModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        expert_interval=1,
        moe_intermediate_size=1408,
        shared_expert_intermediate_size=5632,
        shared_expert_gate=True,
        num_experts_per_tok=4,
        num_experts=60,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        # MoE arguments
        self.expert_interval = expert_interval
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.shared_expert_gate = shared_expert_gate
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
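Beyond the default-construction example in the docstring, the MoE-specific arguments can be exercised directly. A minimal sketch, assuming this branch is installed; the values are illustrative, not those of a released checkpoint:

```python
# Sketch: build a small Qwen2MoE configuration using the MoE arguments defined above.
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoEConfig

config = Qwen2MoEConfig(
    hidden_size=1024,
    num_hidden_layers=4,
    moe_intermediate_size=704,
    shared_expert_intermediate_size=2816,
    num_experts=8,
    num_experts_per_tok=2,
    output_router_logits=True,  # also surfaces the load-balancing auxiliary loss
)
print(config.num_experts, config.router_aux_loss_coef)  # 8 0.001
```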