Commit 6bf4e5b

jackzhxng authored and Zonglin Peng committed
Add Phi-4-mini-instruct (#8856)
1 parent 338d936 commit 6bf4e5b

File tree: 7 files changed (+136, -4 lines)

.ci/scripts/test_model.sh (+8)

@@ -100,6 +100,14 @@ test_model() {
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runner since portable doesn't support Qwen's biased linears.
   fi
+  if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
+    # Install requirements for export_llama
+    bash examples/models/llama/install_requirements.sh
+    # Test export_llama script: python3 -m examples.models.llama.export_llama.
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
+    run_portable_executor_runner
+    rm "./${MODEL_NAME}.pte"
+  fi
 
   # Export a basic .pte and run the model.
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"

examples/models/__init__.py (+1)

@@ -35,6 +35,7 @@
     "llava": ("llava", "LlavaModel"),
     "efficient_sam": ("efficient_sam", "EfficientSAM"),
     "qwen2_5": ("qwen2_5", "Qwen2_5Model"),
+    "phi4_mini": ("phi4_mini", "Phi4MiniModel"),
 }
 
 __all__ = [

examples/models/llama/export_llama_lib.py (+1)

@@ -93,6 +93,7 @@
     "llama3_2",
     "static_llama",
     "qwen2_5",
+    "phi4_mini",
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

examples/models/llama/model_args.py (+1)

@@ -38,6 +38,7 @@ class ModelArgs:
     apply_embedding: bool = True  # Use embedding inside the transformer
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
     )
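
The new partial_rotary_factor field defaults to 1.0, i.e. standard full-width RoPE. In rope.py below it scales the rotary dimension (dim = int(dim * partial_rotary_factor)), and Phi-4-mini's config sets it to 0.75, so only three quarters of each head dimension is rotated.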

examples/models/llama/rope.py (+22, -4)

@@ -134,11 +134,21 @@ def forward(
 
 
 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
-def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
+# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
+# Currently only supports non-long rope.
+def hf_precompute_freqs_cis(
+    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+):
+    # Partial rotary embeddings.
+    dim = int(dim * partial_rotary_factor)
+
+    # Short factor scaling.
     freqs = 1.0 / (
         theta
         ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
     )
+    # TODO: support long factor scaling.
+
     # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
     t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
         freqs  # pyre-ignore
@@ -180,8 +190,13 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
+    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
     return q_embed, k_embed
 
 
@@ -217,7 +232,10 @@ def __init__(self, params: ModelArgs):
 
         # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
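
Taken together, these rope.py changes implement partial rotary embeddings: cos/sin are precomputed only for the first int(head_dim * partial_rotary_factor) dimensions, and hf_apply_rotary_emb rotates that slice while concatenating the remaining dimensions back unchanged. Below is a minimal, self-contained sketch of the same idea, not the repo's code: the helper names and the (batch, seq, heads, head_dim) layout are assumptions for illustration, using Phi-4-mini-like numbers (head_dim 128, factor 0.75).

# Standalone sketch of partial rotary embeddings (hypothetical helper names).
import torch


def rotate_half(x):
    # Swap the two halves of the rotary slice, negating the second half (HF convention).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def precompute_cos_sin(head_dim, seq_len, theta=10000.0, partial_rotary_factor=1.0):
    rotary_dim = int(head_dim * partial_rotary_factor)  # e.g. 128 * 0.75 = 96
    inv_freq = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)          # (seq_len, rotary_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, rotary_dim)
    return emb.cos(), emb.sin()


def apply_partial_rope(q, cos, sin):
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    return torch.cat([q_rot, q_pass], dim=-1)  # pass-through dims are untouched


q = torch.randn(1, 8, 24, 128)  # assumed layout: (bsz, seq, n_heads, head_dim)
cos, sin = precompute_cos_sin(128, 8, partial_rotary_factor=0.75)
cos, sin = cos[None, :, None, :], sin[None, :, None, :]  # broadcast over batch/heads
out = apply_partial_rope(q, cos, sin)
assert torch.equal(out[..., 96:], q[..., 96:])  # last 32 dims are passed through

With partial_rotary_factor = 1.0 the pass-through slice is empty and this reduces to standard full-width RoPE, which is why the default keeps existing models unchanged.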
examples/models/phi-4-mini/config.json (+15, new file)

@@ -0,0 +1,15 @@
+{
+  "dim": 3072,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8192,
+  "n_heads": 24,
+  "n_kv_heads": 8,
+  "n_layers": 32,
+  "norm_eps": 1e-05,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 200064,
+  "use_hf_rope": true,
+  "partial_rotary_factor": 0.75,
+  "attention_qkv_bias": false
+}
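
For reference, these parameters give head_dim = dim / n_heads = 3072 / 24 = 128, so with partial_rotary_factor = 0.75 the rotary slice of each head is 128 × 0.75 = 96 dimensions and the last 32 pass through unrotated. n_kv_heads = 8 means grouped-query attention with 24 / 8 = 3 query heads per KV head.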
New file (+88)

@@ -0,0 +1,88 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune.
+_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _PHI_4_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    # Input and output embeddings are tied.
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+
+    return converted_state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Phi-4-mini weights to Meta format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=args.input_dir,
+        checkpoint_files=[
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+        ],
+        output_dir=".",
+        model_type="PHI3_MINI",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+
+    print("Converting checkpoint...")
+    sd = phi_4_tune_to_meta(sd["model"])
+
+    torch.save(sd, args.output)
+    print(f"Checkpoint saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
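
The converter inverts the _PHI_4_FROM_META mapping and relies on torchtune's get_mapped_key to substitute layer indices. Below is a toy sketch of that key translation only, with tiny random tensors standing in for real weights; it assumes torchtune is installed and reuses a subset of the mapping from the script above.

# Toy illustration of the torchtune -> Meta key mapping used by phi_4_tune_to_meta.
import torch
from torchtune.models.convert_weights import get_mapped_key

# Two entries from the script's _PHI_4_FROM_META mapping (Meta name -> torchtune name).
mapping_from_meta = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
}
inverted = {v: k for k, v in mapping_from_meta.items()}  # torchtune name -> Meta name

# Toy torchtune-style state dict (shapes are deliberately tiny).
tune_sd = {
    "tok_embeddings.weight": torch.randn(8, 4),
    "layers.0.attn.q_proj.weight": torch.randn(4, 4),
}
meta_sd = {get_mapped_key(k, inverted): v for k, v in tune_sd.items()}
print(sorted(meta_sd))  # ['layers.0.attention.wq.weight', 'tok_embeddings.weight']

The checkpointer is created with model_type="PHI3_MINI", which suggests the Phi-4-mini HF checkpoint is read through the existing Phi-3-mini layout before phi_4_tune_to_meta renames the keys into Meta format.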
