
Commit cf99e6f

mayank31398, shawntan, and Ssukriti authored and committed
add shared experts for upcoming Granite 4.0 language models (huggingface#35894)
* Modular GraniteMoE with shared Experts. Signed-off-by: Shawn Tan <shawntan@ibm.com>
* Modified
* Import order.
* Modified for style
* Fix space.
* Test
* Remove extra granitemoe file.
* New converted file and tests
* Modified __init__ files.
* Formatting.
* Dummy PT objects
* register granitemoe shared model Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix linting of a file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix import in modeling file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* update generated modeling file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* add documentation Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* update docstrings Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* update generated modeling file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix docstrings in config class Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* merge main Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Shawn Tan <shawntan@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Co-authored-by: Shawn Tan <shawntan@ibm.com>
Co-authored-by: Shawn Tan <shawn@wtf.sg>
Co-authored-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Co-authored-by: Sukriti Sharma <Ssukriti@users.noreply.github.com>
1 parent 11fcfb2 commit cf99e6f

File tree

15 files changed · +2517 −0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -461,6 +461,8 @@
       title: Granite
     - local: model_doc/granitemoe
       title: GraniteMoe
+    - local: model_doc/granitemoeshared
+      title: GraniteMoeShared
     - local: model_doc/granitevision
       title: GraniteVision
     - local: model_doc/helium
```

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -173,6 +173,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [GPTSAN-japanese](model_doc/gptsan-japanese) ||||
 | [Granite](model_doc/granite) ||||
 | [GraniteMoeMoe](model_doc/granitemoe) ||||
+| [GraniteMoeSharedMoe](model_doc/granitemoeshared) ||||
 | [Graphormer](model_doc/graphormer) ||||
 | [Grounding DINO](model_doc/grounding-dino) ||||
 | [GroupViT](model_doc/groupvit) ||||
```
docs/source/en/model_doc/granitemoeshared.md

Lines changed: 66 additions & 0 deletions
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# GraniteMoeShared

## Overview

The GraniteMoe model was proposed in [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox and Rameswar Panda.

Additionally, the GraniteMoeSharedModel class adds shared experts for the MoE.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-research/moe-7b-1b-active-shared-experts"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()

# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."

# tokenize the text and move it to the model's device
input_tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print, in this example the batch size is 1
for i in output:
    print(i)
```
This HF implementation was contributed by [Mayank Mishra](https://huggingface.co/mayank-mishra), [Shawn Tan](https://huggingface.co/shawntan) and [Sukriti Sharma](https://huggingface.co/SukritiSharma).

## GraniteMoeSharedConfig

[[autodoc]] GraniteMoeSharedConfig

## GraniteMoeSharedModel

[[autodoc]] GraniteMoeSharedModel
    - forward

## GraniteMoeSharedForCausalLM

[[autodoc]] GraniteMoeSharedForCausalLM
    - forward
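
For orientation, the shared-experts idea is that every token passes through one always-active expert in addition to the router-selected experts, and the two outputs are summed. The following is only a minimal sketch of that combination with illustrative module and parameter names, not the actual `modeling_granitemoeshared.py` internals:

```python
import torch
import torch.nn as nn


class MoeWithSharedExpert(nn.Module):
    """Toy MoE block: top-k routed experts plus one always-active shared expert."""

    def __init__(self, hidden_size=64, expert_size=128, shared_size=256, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_size, expert_size), nn.SiLU(), nn.Linear(expert_size, hidden_size))
            for _ in range(num_experts)
        )
        # the shared expert processes every token, independent of the router
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden_size, shared_size), nn.SiLU(), nn.Linear(shared_size, hidden_size)
        )

    def forward(self, hidden_states):
        # naive routing: weight each token's top-k experts by the softmaxed router scores
        probs = torch.softmax(self.router(hidden_states), dim=-1)
        top_probs, top_idx = probs.topk(self.top_k, dim=-1)
        routed = torch.zeros_like(hidden_states)
        for e, expert in enumerate(self.experts):
            for k in range(self.top_k):
                mask = (top_idx[..., k] == e).unsqueeze(-1)
                routed = routed + mask * top_probs[..., k : k + 1] * expert(hidden_states)
        # the shared-expert output is simply added on top of the routed output
        return routed + self.shared_expert(hidden_states)


x = torch.randn(1, 10, 64)
print(MoeWithSharedExpert()(x).shape)  # torch.Size([1, 10, 64])
```

The loop over experts is deliberately naive; it is there to show the routing and summation, not an efficient implementation.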

docs/source/en/perf_infer_gpu_one.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -60,6 +60,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
 * [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
 * [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
+* [GraniteMoeShared](https://huggingface.co/docs/transformers/model_doc/granitemoeshared#transformers.GraniteMoeSharedModel)
 * [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
 * [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
@@ -266,6 +267,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
 * [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel)
 * [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
+* [GraniteMoeShared](https://huggingface.co/docs/transformers/model_doc/granitemoeshared#transformers.GraniteMoeSharedModel)
 * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
```
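
As with the other architectures in these lists, the attention backend is chosen at load time via the standard `attn_implementation` argument of `from_pretrained`. A minimal sketch, reusing the checkpoint name from the model doc above (use `"flash_attention_2"` only if the flash-attn package and a supported GPU are available):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "ibm-research/moe-7b-1b-active-shared-experts",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "flash_attention_2"
    device_map="auto",
)
```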

src/transformers/__init__.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -496,6 +496,7 @@
     "models.gptj": ["GPTJConfig"],
     "models.granite": ["GraniteConfig"],
     "models.granitemoe": ["GraniteMoeConfig"],
+    "models.granitemoeshared": ["GraniteMoeSharedConfig"],
     "models.grounding_dino": [
         "GroundingDinoConfig",
         "GroundingDinoProcessor",
@@ -2539,6 +2540,14 @@
             "GraniteMoePreTrainedModel",
         ]
     )
+    _import_structure["models.granitemoeshared"].extend(
+        [
+            "GraniteMoeSharedForCausalLM",
+            "GraniteMoeSharedModel",
+            "GraniteMoeSharedPreTrainedModel",
+        ]
+    )
+
     _import_structure["models.grounding_dino"].extend(
         [
             "GroundingDinoForObjectDetection",
@@ -5605,6 +5614,7 @@
     from .models.gptj import GPTJConfig
     from .models.granite import GraniteConfig
     from .models.granitemoe import GraniteMoeConfig
+    from .models.granitemoeshared import GraniteMoeSharedConfig
     from .models.grounding_dino import (
         GroundingDinoConfig,
         GroundingDinoProcessor,
@@ -7479,6 +7489,11 @@
         GraniteMoeModel,
         GraniteMoePreTrainedModel,
     )
+    from .models.granitemoeshared import (
+        GraniteMoeSharedForCausalLM,
+        GraniteMoeSharedModel,
+        GraniteMoeSharedPreTrainedModel,
+    )
     from .models.grounding_dino import (
         GroundingDinoForObjectDetection,
         GroundingDinoModel,
```
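
These `_import_structure` entries and the matching `TYPE_CHECKING` imports expose the new classes at the top level of the package. A quick sketch of what that enables once a build containing this commit is installed:

```python
from transformers import (
    GraniteMoeSharedConfig,
    GraniteMoeSharedForCausalLM,
    GraniteMoeSharedModel,
    GraniteMoeSharedPreTrainedModel,
)

# the config carries the registered model type string that the Auto* classes key on
config = GraniteMoeSharedConfig()
print(config.model_type)  # "granitemoeshared"
```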

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -118,6 +118,7 @@
     gptj,
     granite,
     granitemoe,
+    granitemoeshared,
     grounding_dino,
     groupvit,
     helium,
```

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -137,6 +137,7 @@
         ("gptsan-japanese", "GPTSanJapaneseConfig"),
         ("granite", "GraniteConfig"),
         ("granitemoe", "GraniteMoeConfig"),
+        ("granitemoeshared", "GraniteMoeSharedConfig"),
         ("granitevision", "LlavaNextConfig"),
         ("graphormer", "GraphormerConfig"),
         ("grounding-dino", "GroundingDinoConfig"),
@@ -467,6 +468,7 @@
         ("gptsan-japanese", "GPTSAN-japanese"),
         ("granite", "Granite"),
         ("granitemoe", "GraniteMoeMoe"),
+        ("granitemoeshared", "GraniteMoeSharedMoe"),
         ("granitevision", "LLaVA-NeXT"),
         ("graphormer", "Graphormer"),
         ("grounding-dino", "Grounding DINO"),
```

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -132,6 +132,7 @@
         ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
         ("granite", "GraniteModel"),
         ("granitemoe", "GraniteMoeModel"),
+        ("granitemoeshared", "GraniteMoeSharedModel"),
         ("graphormer", "GraphormerModel"),
         ("grounding-dino", "GroundingDinoModel"),
         ("groupvit", "GroupViTModel"),
@@ -526,6 +527,7 @@
         ("gptj", "GPTJForCausalLM"),
         ("granite", "GraniteForCausalLM"),
         ("granitemoe", "GraniteMoeForCausalLM"),
+        ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
         ("helium", "HeliumForCausalLM"),
         ("jamba", "JambaForCausalLM"),
         ("jetmoe", "JetMoeForCausalLM"),
```
src/transformers/models/granitemoeshared/__init__.py

Lines changed: 27 additions & 0 deletions
```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_granitemoeshared import *
    from .modeling_granitemoeshared import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
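
For context, `_LazyModule` plus `define_import_structure` means importing this package module is cheap: the configuration and modeling submodules are only imported the first time one of their exported names is accessed. A rough sketch of that behaviour, with the module path taken from the package layout above:

```python
import importlib

# importing the package module itself does not yet pull in the modeling code
pkg = importlib.import_module("transformers.models.granitemoeshared")

# the first attribute access triggers the lazy import of modeling_granitemoeshared
model_cls = pkg.GraniteMoeSharedModel
print(model_cls.__module__)  # transformers.models.granitemoeshared.modeling_granitemoeshared
```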
