
Commit b8f2a5b

Update
[ghstack-poisoned]
2 parents 0d53925 + 3a2d555 commit b8f2a5b

32 files changed: +425 −170 lines

.buckconfig

Lines changed: 5 additions & 4 deletions
```diff
@@ -8,14 +8,15 @@
 root = .
 prelude = third-party/prelude
 shim = shim
+shim_et = shim_et

 [repository_aliases]
 config = prelude
 ovr_config = prelude
-toolchains = shim
-fbcode = shim
-fbcode_macros = shim
-fbsource = shim
+toolchains = shim_et
+fbcode = shim_et
+fbcode_macros = shim_et
+fbsource = shim_et
 buck = shim

 [cxx]
```

build/Utils.cmake

Lines changed: 2 additions & 2 deletions
```diff
@@ -206,9 +206,9 @@ function(extract_sources sources_file)

   if(ANDROID_ABI)
     if("${ANDROID_ABI}" STREQUAL "arm64-v8a")
-      set(target_platforms_arg "--target-platforms=shim//:android-arm64")
+      set(target_platforms_arg "--target-platforms=shim_et//:android-arm64")
     elseif("${ANDROID_ABI}" STREQUAL "x86_64")
-      set(target_platforms_arg "--target-platforms=shim//:android-x86_64")
+      set(target_platforms_arg "--target-platforms=shim_et//:android-x86_64")
     else()
       message(
         FATAL_ERROR
```

build/build_android_llm_demo.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ build_android_native_library() {
   fi
   cmake --build "${CMAKE_OUT}" -j "${CMAKE_JOBS}" --target install --config "${EXECUTORCH_CMAKE_BUILD_TYPE}"

-  cmake --trace extension/android \
+  cmake extension/android \
     -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI="${ANDROID_ABI}" \
     -DANDROID_PLATFORM=android-26 \
```

examples/models/llama/attention.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -175,9 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_batch_size = args.max_batch_size
         self.max_context_len = args.max_context_len
         self.dim = args.dim
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.attention_qkv_bias = args.attention_qkv_bias
+        self.wq = nn.Linear(
+            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wk = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wv = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
```

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     attention_type: str = "mha"  # Attention type, registered in attention.py
+    attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
```
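
For context, the new `attention_qkv_bias` field added to `ModelArgs` above is what the attention changes read from. Below is a minimal, hypothetical sketch (not part of this commit; `TinyArgs` is a made-up stand-in) of how the flag is threaded into the Q/K/V projections:

```python
from dataclasses import dataclass

import torch.nn as nn


# Hypothetical, trimmed-down stand-in for ModelArgs, only to show the flag's effect.
@dataclass
class TinyArgs:
    dim: int = 64
    n_heads: int = 4
    head_dim: int = 16
    attention_qkv_bias: bool = True  # Qwen 2.5 sets this to True in its params JSON


args = TinyArgs()
# Same pattern as the updated attention.py: bias on Q/K/V is controlled by the flag.
wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=args.attention_qkv_bias)
print(wq.bias is not None)  # True only when attention_qkv_bias is enabled
```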

examples/models/llama/rope.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
     return xk_out.type_as(xk)


+# Wrap apply_rotary_emb in a module to enable it to be module swapped out.
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -213,14 +214,20 @@ class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
+
+        # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.apply_rotary_emb = hf_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
             )
+            self.apply_rotary_emb = RotaryEmbedding()
+
+        # Precompute frequencies
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
@@ -232,10 +239,6 @@ def __init__(self, params: ModelArgs):
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
-        if self.params.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
```
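
The comment added above ("wrap apply_rotary_emb in a module to enable it to be module swapped out") is the motivation for selecting the implementation once in `__init__`. A small, hypothetical sketch of what such a module swap could look like (all class names below are illustrative, not from the commit):

```python
import torch


class DefaultRotary(torch.nn.Module):
    """Illustrative stand-in for RotaryEmbedding from rope.py."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        return xq, xk  # the real module applies the rotation here


class BackendSpecificRotary(torch.nn.Module):
    """Hypothetical replacement, e.g. a quantization- or backend-friendly variant."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        return xq, xk


class ToyRope(torch.nn.Module):
    """Stand-in for Rope: the implementation is picked once at construction."""

    def __init__(self):
        super().__init__()
        self.apply_rotary_emb = DefaultRotary()


rope = ToyRope()
# Because apply_rotary_emb is an nn.Module attribute, it can be swapped wholesale:
rope.apply_rotary_emb = BackendSpecificRotary()
print(type(rope.apply_rotary_emb).__name__)  # BackendSpecificRotary
```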

examples/models/llama/static_attention.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -207,22 +207,23 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.dim = config.dim
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
+        self.attention_qkv_bias = config.attention_qkv_bias

         self.wqs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_heads)
             ]
         )
         self.wks = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
         self.wvs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
```

examples/models/qwen2_5/1_5b_config.json

Lines changed: 14 additions & 0 deletions

```diff
@@ -0,0 +1,14 @@
+{
+  "dim": 1536,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8960,
+  "n_heads": 12,
+  "n_kv_heads": 2,
+  "n_layers": 28,
+  "norm_eps": 1e-06,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 151936,
+  "use_hf_rope": true,
+  "attention_qkv_bias": true
+}
```
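
As a rough illustration (not part of the commit), the params file above can be inspected with plain `json` to see the settings that the new Llama-side flags pick up; the path assumes the repository root as the working directory:

```python
import json

# Hedged sketch: read the Qwen 2.5 1.5B params file and print a few fields.
with open("examples/models/qwen2_5/1_5b_config.json") as f:
    params = json.load(f)

print(params["n_heads"], params["n_kv_heads"])  # 12, 2 -> grouped-query attention
print(params["use_hf_rope"])                    # True -> HF-style RoPE path in rope.py
print(params["attention_qkv_bias"])             # True -> bias on the Q/K/V projections
```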

examples/models/qwen2_5/README.md

Lines changed: 63 additions & 0 deletions
```diff
@@ -0,0 +1,63 @@
+## Summary
+Qwen 2.5 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. At the moment, only the 1.5b variant is supported, with plans to add the 0.5b and 3b versions in the future.
+
+## Instructions
+
+Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.
+
+All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5, by swapping the following args:
+```
+--model qwen2_5
+--params examples/models/qwen2_5/1_5b_config.json
+--checkpoint <path-to-meta-checkpoint>
+```
+
+### Generate the Checkpoint
+The original checkpoint can be obtained from HuggingFace:
+```
+huggingface-cli download Qwen/Qwen2.5-1.5B
+```
+
+We then convert it to Meta's checkpoint format:
+```
+python examples/models/qwen2_5/convert_weights.py <path-to-checkpoint-dir> <output-path>
+```
+
+### Example export and run
+Here is a basic example of exporting and running Qwen 2.5; please refer to the [Llama README page](../llama/README.md) for more advanced usage.
+
+Export to XNNPACK, no quantization:
+```
+# No quantization
+# Set these paths to point to the downloaded files
+QWEN_CHECKPOINT=path/to/checkpoint.pth
+
+python -m examples.models.llama.export_llama \
+  --model "qwen2_5" \
+  --checkpoint "${QWEN_CHECKPOINT:?}" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -d fp32 \
+  -X \
+  --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \
+  --output_name="qwen2_5-1_5b.pte" \
+  --verbose
+```
+
+Run using the executor runner:
+```
+# Currently a work in progress; the HuggingFace JSON tokenizer still needs to be enabled in C++.
+# In the meantime, you can run with an example Python runner via pybindings:
+
+python -m examples.models.llama.runner.native \
+  --model qwen2_5 \
+  --pte <path-to-pte> \
+  -kv \
+  --tokenizer <path-to-tokenizer>/tokenizer.json \
+  --tokenizer_config <path-to-tokenizer>/tokenizer_config.json \
+  --prompt "Who is the founder of Meta?" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  --max_len 64 \
+  --temperature 0
+```
```

examples/models/qwen2_5/convert_weights.py

Lines changed: 90 additions & 0 deletions

```diff
@@ -0,0 +1,90 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
+_QWEN_2_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+
+    return converted_state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Qwen2 weights to Meta format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+
+    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=args.input_dir,
+        checkpoint_files=["model.safetensors"],
+        output_dir=".",
+        model_type="QWEN2",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+
+    print("Converting checkpoint...")
+    sd = qwen_2_tune_to_meta(sd["model"])
+
+    torch.save(sd, args.output)
+    print(f"Checkpoint saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
```
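
To make the key mapping concrete, here is a small, hedged example (not from the commit) of how `get_mapped_key` resolves a layer-indexed torchtune key back to a Meta-format name via the inverted `_QWEN_2_FROM_META` table; only two entries of the table are reproduced:

```python
from torchtune.models.convert_weights import get_mapped_key

# torchtune name -> Meta name, obtained by inverting two _QWEN_2_FROM_META entries.
inverted = {
    "layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "layers.{}.attn.q_proj.bias": "layers.{}.attention.wq.bias",
}

# The layer index is abstracted to "{}" for the lookup and substituted back afterwards.
print(get_mapped_key("layers.0.attn.q_proj.weight", inverted))
# layers.0.attention.wq.weight
print(get_mapped_key("layers.27.attn.q_proj.bias", inverted))
# layers.27.attention.wq.bias
```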

extension/android_test/setup.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ build_native_library() {

   cmake --build "${CMAKE_OUT}" -j16 --target install

-  cmake extension/android \
+  cmake --trace extension/android \
    -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI="${ANDROID_ABI}" \
    -DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \
```
