Commit b9de90b

Merge branch 'main' into benchmarking_script
2 parents: d53f7a5 + 83314c7

File tree: 6 files changed (+140 −3 lines)


torchchat/model.py

Lines changed: 31 additions & 3 deletions
@@ -287,6 +287,11 @@ class TransformerArgs:
     feed_forward_bias: bool = False
     # Whether or not to tie the input word embeddings to the output
     tie_word_embeddings: bool = False
+    # Granite architecture multipliers
+    embedding_multiplier: Optional[float] = None
+    attention_multiplier: Optional[float] = None
+    residual_multiplier: Optional[float] = None
+    logits_scaling: Optional[float] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -723,13 +728,20 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
         if self.tok_embeddings:
             x = self.tok_embeddings(x)
 
+        # For Granite architectures
+        if self.config.embedding_multiplier:
+            x = x * self.config.embedding_multiplier
+
         for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
 
         if self.norm:
             x = self.norm(x)
         if self.output:
             x = self.output(x)
+            # For granite architectures
+            if self.config.logits_scaling:
+                x = x / self.config.logits_scaling
         # print(f"output shape: {x.shape}")
         return x
 
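Taken together with the block- and attention-level changes below, the Granite path differs from the existing Llama path only by four scalars, all of which are a no-op when left unset. A minimal sketch of the top-level math implied by this hunk (illustrative only; emb, blocks, norm, lm_head, and cfg are stand-in names, not torchchat objects):

# Sketch of the forward pass implied by this diff; not the torchchat implementation.
# cfg carries the new TransformerArgs fields; unset multipliers leave behaviour unchanged.
def granite_style_forward(tokens, cfg, emb, blocks, norm, lm_head):
    x = emb(tokens)
    if cfg.embedding_multiplier:        # e.g. 12.0 for the Granite 3.x configs below
        x = x * cfg.embedding_multiplier
    for block in blocks:                # each block also scales its residual branches
        x = block(x)
    x = lm_head(norm(x))
    if cfg.logits_scaling:              # e.g. 8.0 (2B) or 16.0 (8B)
        x = x / cfg.logits_scaling
    return x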

@@ -741,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        # None for llama architecture, set for granite architectures
+        self.residual_multiplier = (
+            config.residual_multiplier
+            if config.residual_multiplier is not None
+            else 1.0
+        )
 
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
@@ -751,8 +769,8 @@ def forward(
     ) -> Tensor:
         h = x + self.attention(
             self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        ) * self.residual_multiplier
+        out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out
 
 
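With residual_multiplier set (0.22 in the Granite 3.x params below), both residual branches of a block are damped; with the 1.0 default the block reduces to the existing Llama behaviour. A toy restatement of the block math (stand-in callables, not the torchchat modules):

# Illustrative only: attn, ffn, attn_norm, ffn_norm stand in for the real submodules.
def granite_style_block(x, attn, ffn, attn_norm, ffn_norm, residual_multiplier=1.0):
    h = x + attn(attn_norm(x)) * residual_multiplier
    return h + ffn(ffn_norm(h)) * residual_multiplier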

@@ -779,6 +797,7 @@ def __init__(self, config: TransformerArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.attention_scale = config.attention_multiplier
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -875,7 +894,16 @@ def forward(
 
         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            # This is None (default) for llama architecture and set for granite
+            # architectures
+            scale=self.attention_scale,
+        )
 
         # -1 = self.dim
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
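Passing scale=None keeps F.scaled_dot_product_attention's default softmax scaling of 1/sqrt(head_dim), so existing Llama configs are unaffected; the Granite configs below override it through attention_multiplier. For the values added in this commit the override works out to 1/head_dim (a quick check, assuming head_dim = dim / n_heads):

# Arithmetic check of the attention_multiplier values used in the new param files.
head_dim_2b = 2048 // 32             # dim / n_heads = 64
head_dim_8b = 4096 // 32             # dim / n_heads = 128
assert 1 / head_dim_2b == 0.015625   # Granite 3.x 2B attention_multiplier
assert 1 / head_dim_8b == 0.0078125  # Granite 3.x 8B attention_multiplier
# The SDPA default would instead be head_dim ** -0.5, i.e. 0.125 and ~0.0884.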

torchchat/model_config/models.json

Lines changed: 28 additions & 0 deletions
@@ -178,5 +178,33 @@
     "distribution_path": "ibm-granite/granite-8b-code-instruct-128k",
     "transformer_params_key": "Granite-8B-Code",
     "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.0-2b-instruct": {
+    "aliases": ["granite3-2b", "granite3"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.0-2b-instruct",
+    "transformer_params_key": "Granite-3.0-2B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.0-8b-instruct": {
+    "aliases": ["granite3-8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.0-8b-instruct",
+    "transformer_params_key": "Granite-3.0-8B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.1-2b-instruct": {
+    "aliases": ["granite3.1-2b", "granite3.1"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.1-2b-instruct",
+    "transformer_params_key": "Granite-3.1-2B-Instruct",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-3.1-8b-instruct": {
+    "aliases": ["granite3.1-8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3.1-8b-instruct",
+    "transformer_params_key": "Granite-3.1-8B-Instruct",
+    "tokenizer_file": "tokenizer.json"
   }
 }
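The new entries are addressable through their aliases like any other model in models.json; for example, assuming torchchat's standard download/generate CLI (the prompt is purely illustrative):

python3 torchchat.py download granite3.1
python3 torchchat.py generate granite3.1 --prompt "Write a haiku about tokenizers"

Each entry's transformer_params_key refers to one of the four new parameter files that follow; judging by dim and rope_base they appear in the same order: 3.0-2B, 3.0-8B, 3.1-2B, 3.1-8B.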
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+{
+    "block_size": 8192,
+    "dim": 2048,
+    "hidden_dim": 8192,
+    "n_heads": 32,
+    "n_local_heads": 8,
+    "n_layers": 40,
+    "rope_base": 10000,
+    "vocab_size": 49155,
+    "use_hf_tokenizer": true,
+    "tokenizer_prepend_bos": false,
+    "norm_eps": 0.00001,
+    "rope_scaling": null,
+    "attention_bias": false,
+    "feed_forward_bias": false,
+    "tie_word_embeddings": true,
+    "embedding_multiplier": 12.0,
+    "attention_multiplier": 0.015625,
+    "residual_multiplier": 0.22,
+    "logits_scaling": 8.0
+}
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+{
+    "attention_multiplier": 0.0078125,
+    "embedding_multiplier": 12.0,
+    "dim": 4096,
+    "block_size": 12800,
+    "hidden_dim": 12800,
+    "logits_scaling": 16.0,
+    "n_heads": 32,
+    "n_layers": 40,
+    "n_local_heads": 8,
+    "residual_multiplier": 0.22,
+    "norm_eps": 1e-05,
+    "rope_base": 10000,
+    "tie_word_embeddings": true,
+    "vocab_size": 49155,
+    "use_hf_tokenizer": true,
+    "tokenizer_prepend_bos": false,
+    "attention_bias": false,
+    "feed_forward_bias": false
+}
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+{
+    "attention_multiplier": 0.015625,
+    "embedding_multiplier": 12.0,
+    "dim": 2048,
+    "block_size": 8192,
+    "hidden_dim": 8192,
+    "logits_scaling": 8.0,
+    "n_heads": 32,
+    "n_layers": 40,
+    "n_local_heads": 8,
+    "residual_multiplier": 0.22,
+    "norm_eps": 1e-05,
+    "rope_base": 5000000.0,
+    "tie_word_embeddings": true,
+    "vocab_size": 49155,
+    "use_hf_tokenizer": true,
+    "tokenizer_prepend_bos": false,
+    "attention_bias": false,
+    "feed_forward_bias": false
+}
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+{
+    "attention_multiplier": 0.0078125,
+    "embedding_multiplier": 12.0,
+    "dim": 4096,
+    "block_size": 12800,
+    "hidden_dim": 12800,
+    "logits_scaling": 16.0,
+    "n_heads": 32,
+    "n_layers": 40,
+    "n_local_heads": 8,
+    "residual_multiplier": 0.22,
+    "norm_eps": 1e-05,
+    "rope_base": 10000000.0,
+    "tie_word_embeddings": true,
+    "vocab_size": 49155,
+    "use_hf_tokenizer": true,
+    "tokenizer_prepend_bos": false,
+    "attention_bias": false,
+    "feed_forward_bias": false
+}
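For reference, a compact side-by-side of the four new Granite 3.x parameter sets (values transcribed from the files above; the name-to-file mapping is inferred from dim and rope_base as noted earlier):

# Shared by all four files: n_heads=32, n_local_heads=8, n_layers=40,
# vocab_size=49155, embedding_multiplier=12.0, residual_multiplier=0.22,
# tie_word_embeddings=True, use_hf_tokenizer=True.
granite_3x_params = {
    "Granite-3.0-2B-Instruct": dict(
        dim=2048, hidden_dim=8192, rope_base=10_000,
        attention_multiplier=0.015625, logits_scaling=8.0),
    "Granite-3.0-8B-Instruct": dict(
        dim=4096, hidden_dim=12800, rope_base=10_000,
        attention_multiplier=0.0078125, logits_scaling=16.0),
    "Granite-3.1-2B-Instruct": dict(
        dim=2048, hidden_dim=8192, rope_base=5_000_000.0,
        attention_multiplier=0.015625, logits_scaling=8.0),
    "Granite-3.1-8B-Instruct": dict(
        dim=4096, hidden_dim=12800, rope_base=10_000_000.0,
        attention_multiplier=0.0078125, logits_scaling=16.0),
}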
