Skip to content

Commit 2e9c71c

Browse files
authored
Qualcomm AI Engine Direct - GA PVT (#11035)
Summary: - Add PVT example script - Add the test for PVT
1 parent ef6393a commit 2e9c71c

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4089,6 +4089,42 @@ def test_gMLP(self):
40894089
self.assertGreaterEqual(msg["top_1"], 60)
40904090
self.assertGreaterEqual(msg["top_5"], 90)
40914091

4092+
def test_pvt(self):
4093+
if not self.required_envs([self.image_dataset]):
4094+
self.skipTest("missing required envs")
4095+
4096+
cmds = [
4097+
"python",
4098+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/pvt.py",
4099+
"--dataset",
4100+
self.image_dataset,
4101+
"--artifact",
4102+
self.artifact_dir,
4103+
"--build_folder",
4104+
self.build_folder,
4105+
"--device",
4106+
self.device,
4107+
"--model",
4108+
self.model,
4109+
"--ip",
4110+
self.ip,
4111+
"--port",
4112+
str(self.port),
4113+
]
4114+
if self.host:
4115+
cmds.extend(["--host", self.host])
4116+
4117+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
4118+
with Listener((self.ip, self.port)) as listener:
4119+
conn = listener.accept()
4120+
p.communicate()
4121+
msg = json.loads(conn.recv())
4122+
if "Error" in msg:
4123+
self.fail(msg["Error"])
4124+
else:
4125+
self.assertGreaterEqual(msg["top_1"], 65)
4126+
self.assertGreaterEqual(msg["top_5"], 85)
4127+
40924128
def test_regnet(self):
40934129
if not self.required_envs([self.image_dataset]):
40944130
self.skipTest("missing required envs")

examples/qualcomm/oss_scripts/pvt.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import logging
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
14+
import torch
15+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16+
from executorch.examples.qualcomm.utils import (
17+
build_executorch_binary,
18+
get_imagenet_dataset,
19+
make_output_dir,
20+
parse_skip_delegation_node,
21+
setup_common_args_and_variables,
22+
SimpleADB,
23+
topk_accuracy,
24+
)
25+
from transformers import AutoModelForImageClassification
26+
27+
28+
def main(args):
29+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
30+
31+
# ensure the working directory exist.
32+
os.makedirs(args.artifact, exist_ok=True)
33+
34+
if not args.compile_only and args.device is None:
35+
raise RuntimeError(
36+
"device serial is required if not compile only. "
37+
"Please specify a device serial by -s/--device argument."
38+
)
39+
40+
data_num = 100
41+
if args.ci:
42+
inputs = [(torch.rand(1, 3, 224, 224),)]
43+
logging.warning(
44+
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
45+
)
46+
else:
47+
inputs, targets, input_list = get_imagenet_dataset(
48+
dataset_path=f"{args.dataset}",
49+
data_size=data_num,
50+
image_shape=(256, 256),
51+
crop_size=224,
52+
)
53+
54+
module = (
55+
AutoModelForImageClassification.from_pretrained("Zetatech/pvt-tiny-224")
56+
.eval()
57+
.to("cpu")
58+
)
59+
60+
pte_filename = "pvt_qnn_q8"
61+
build_executorch_binary(
62+
module.eval(),
63+
inputs[0],
64+
args.model,
65+
f"{args.artifact}/{pte_filename}",
66+
inputs,
67+
skip_node_id_set=skip_node_id_set,
68+
skip_node_op_set=skip_node_op_set,
69+
quant_dtype=QuantDtype.use_8a8w,
70+
shared_buffer=args.shared_buffer,
71+
)
72+
73+
if args.compile_only:
74+
return
75+
76+
adb = SimpleADB(
77+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
78+
build_path=f"{args.build_folder}",
79+
pte_path=f"{args.artifact}/{pte_filename}.pte",
80+
workspace=f"/data/local/tmp/executorch/{pte_filename}",
81+
device_id=args.device,
82+
host_id=args.host,
83+
soc_model=args.model,
84+
shared_buffer=args.shared_buffer,
85+
)
86+
adb.push(inputs=inputs, input_list=input_list)
87+
adb.execute()
88+
89+
# collect output data
90+
output_data_folder = f"{args.artifact}/outputs"
91+
make_output_dir(output_data_folder)
92+
93+
adb.pull(output_path=args.artifact)
94+
95+
# top-k analysis
96+
predictions = []
97+
for i in range(data_num):
98+
predictions.append(
99+
np.fromfile(
100+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
101+
)
102+
)
103+
104+
k_val = [1, 5]
105+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
106+
if args.ip and args.port != -1:
107+
with Client((args.ip, args.port)) as conn:
108+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
109+
else:
110+
for i, k in enumerate(k_val):
111+
print(f"top_{k}->{topk[i]}%")
112+
113+
114+
if __name__ == "__main__":
115+
parser = setup_common_args_and_variables()
116+
117+
parser.add_argument(
118+
"-d",
119+
"--dataset",
120+
help=(
121+
"path to the validation folder of ImageNet dataset. "
122+
"e.g. --dataset imagenet-mini/val "
123+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
124+
),
125+
type=str,
126+
required=False,
127+
)
128+
129+
parser.add_argument(
130+
"-a",
131+
"--artifact",
132+
help="path for storing generated artifacts by this example. " "Default ./pvt",
133+
default="./pvt",
134+
type=str,
135+
)
136+
137+
args = parser.parse_args()
138+
try:
139+
main(args)
140+
except Exception as e:
141+
if args.ip and args.port != -1:
142+
with Client((args.ip, args.port)) as conn:
143+
conn.send(json.dumps({"Error": str(e)}))
144+
else:
145+
raise Exception(e)

0 commit comments

Comments
 (0)