
Commit 310466a

Integrate byoc preprocess in collage and benchmark (apache#26)
Integrate an implicit call of the BYOC preprocessing module into the Collage tuning module, and enable the benchmark script for Adreno targets.

Benchmark results (mean inference time in seconds):

| Network | OpenCL texture | OpenCLML | Collage |
| --- | --- | --- | --- |
| resnet-18-float32 | 0.010584622 | 0.00720695 | 0.007289728 |
| resnet-18-float16 | 0.007052029 | 0.0045642 | 0.004857585 |
| resnet-34-float32 | 0.016259185 | 0.01242092 | 0.013071063 |
| resnet-34-float16 | 0.011350326 | 0.0073473 | 0.00796802 |
| resnet-50-float32 | 0.019188419 | 0.02085548 | 0.018910226 |
| resnet-50-float16 | 0.01338978 | 0.01199576 | 0.011089206 |
| densenet-121-float32 | 0.025430062 | 0.01798478 | 0.013212844 |
| densenet-121-float16 | 0.012384599 | 0.01101491 | 0.008722716 |
| inception_v3-float32 | 0.040408253 | 0.02229727 | 0.022636675 |
| inception_v3-float16 | 0.029910533 | 0.01368941 | 0.014519823 |
| mobilenet-float32 | 0.004093148 | 0.00367917 | 0.003189258 |
| mobilenet-float16 | 0.00280268 | 0.00244494 | 0.002101514 |

Co-authored-by: krishnaraj36 <quic_kvegiraj@quicinc.com>
1 parent 51260a6 commit 310466a
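
Reading the table: the times are what the script returns from `profile.mean`, i.e. mean latency in seconds, and Collage roughly tracks the better of the two single-backend columns, sometimes beating both (e.g. densenet-121, mobilenet). A quick sanity check on the densenet-121-float32 row:

    # Speedup of Collage over each single backend, densenet-121-float32 row.
    opencl_s, clml_s, collage_s = 0.025430062, 0.01798478, 0.013212844
    print(f"Collage vs OpenCL texture: {opencl_s / collage_s:.2f}x")  # ~1.92x
    print(f"Collage vs OpenCLML:       {clml_s / collage_s:.2f}x")    # ~1.36x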

File tree

8 files changed: +525 −242 lines changed
Lines changed: 384 additions & 0 deletions
@@ -0,0 +1,384 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Compares Collage with various other baselines."""
import argparse
import tvm
from tvm import relay
import logging
import os
import sys
import numpy as np
from tvm.relay import testing
from tvm.contrib.utils import tempdir
from tvm import rpc
from tvm.relay.build_module import bind_params_by_name
from tvm import autotvm
from tvm.runtime.vm import VirtualMachine
import tvm.contrib.graph_executor as runtime
from tvm.contrib import utils, ndk
from tvm.relay.collage.collage import *
from tvm.relay.op.contrib import clml

logging.basicConfig(level=logging.INFO)


###
### How aggressively to look for candidates?
###
TVM_MAX_DEPTH = 8
BYOC_MAX_DEPTH = 8

##
## Default config definition
##
HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android")
OPENCL = tvm.target.Target("opencl -device=adreno", HOST)
NDK_CC = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")


def print_progress(msg):
    """Print a progress message that overwrites the current line.

    Parameters
    ----------
    msg: str
        The message to print
    """
    sys.stdout.write(msg + "\r")
    sys.stdout.flush()


def tune_tasks(
    tasks,
    measure_option,
    tuner="xgb",
    n_trial=1024,
    early_stopping=None,
    log_filename="tuning.log",
):
    # Import all tuner classes the branches below may instantiate, not just XGBTuner.
    from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner

    tmp_log_file = log_filename + ".tmp"

    for i, tsk in enumerate(reversed(tasks)):
        print("Task: ", tsk)
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == "xgb":
            tuner_obj = XGBTuner(tsk, loss_type="reg")
        elif tuner == "xgb_knob":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob")
        elif tuner == "xgb_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar")
        elif tuner == "xgb_curve":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve")
        elif tuner == "xgb_rank":
            tuner_obj = XGBTuner(tsk, loss_type="rank")
        elif tuner == "xgb_rank_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
        elif tuner == "xgb_rank_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar")
        elif tuner == "xgb_rank_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve")
        elif tuner == "xgb_rank_binary":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary")
        elif tuner == "xgb_rank_binary_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob")
        elif tuner == "xgb_rank_binary_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar")
        elif tuner == "xgb_rank_binary_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve")
        elif tuner == "ga":
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(tsk)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(
            n_trial=tsk_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file),
            ],
        )

    autotvm.record.pick_best(tmp_log_file, log_filename)


########### Collage Drivers ###########


def compile_and_run(label, model, targets, inputs):
    """Compile the model for the given targets and benchmark it over RPC."""
    logging.info(f"Compiling {model['name']} using {label} with {targets}...")
    mod = model["mod"]
    exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"])
    lib = exe.mod
    temp = utils.tempdir()
    dso_binary = "dev_lib_cl.so"
    dso_binary_path = temp.relpath(dso_binary)
    logging.info(f"Exporting library to {dso_binary_path}...")
    lib.export_library(dso_binary_path, cc=NDK_CC)
    tracker = rpc.connect_tracker(args.host, args.port)
    remote = tracker.request(args.rpc_key, priority=0, session_timeout=600)
    ctx = remote.cl(0)
    remote.upload(dso_binary_path)
    rlib = remote.load_module(dso_binary)
    vm_factory = tvm.runtime.vm.VirtualMachine(rlib, ctx, "naive")
    func_name = "main"
    # arg_for (from tvm.relay.collage.collage) builds a suitably typed random
    # argument on the remote device for each parameter of main.
    main_args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod[func_name].params}
    profile = vm_factory.benchmark(
        ctx, repeat=5, number=20, min_repeat_ms=0, func_name=func_name, **main_args
    )
    # benchmark() reports times in seconds.
    return profile.mean
def collage(model, input_data, tune_log=""):
    """Run the Collage partitioner over the OpenCL and CLML Adreno targets and profile the result."""
    logging.info(f"collage | {model['name']}")
    logging.info("-------------- BEGIN ORIGINAL --------------")
    logging.info(model["mod"])
    logging.info("-------------- END ORIGINAL ----------------")
    with autotvm.apply_history_best(tune_log):
        targets = []
        targets.append(OPENCL)
        use_fp16 = model["main_dtype"] == "float16"
        targets.append(tvm.target.Target("clml", HOST))

        # Register the BYOC fusion style for each compiler; the available
        # options are [compiler.NoFusion | compiler.TVMFusion | compiler.MaxDepthFusion].
        config = {
            "relay.collage.tvm_max_depth": TVM_MAX_DEPTH,
            "relay.collage.byoc_max_depth": BYOC_MAX_DEPTH,
            "relay.collage.byoc_fusion_style": ["clml.NoFusion"],
        }
        logging.info(f"Using PassContext(config={config})")
        ctxt = tvm.transform.PassContext(config=config)
        config = tvm.target.make_compilation_config(ctxt, targets)
        with ctxt:
            mod = model["mod"]
            # Collage partition with the TVM OpenCL and CLML targets on the RPC device.
            mod = tvm.relay.transform.CollagePartition(
                config,
                cost_estimator=CostEstimator(
                    host=args.host, port=args.port, rpc_key=args.rpc_key, ndk_cc=NDK_CC
                ),
            )(mod)
            partitioned_model = model.copy()
            partitioned_model["mod"] = mod
            logging.info("-------------- BEGIN PARTITIONED --------------")
            logging.info(partitioned_model["mod"])
            logging.info("-------------- END PARTITIONED ----------------")
            return compile_and_run("collage", partitioned_model, targets, input_data)
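
# A minimal usage sketch for collage(), assuming `args` has already been parsed by
# the __main__ block below and an RPC tracker with an Adreno device is reachable:
#
#     model = get_model("resnet-18", "float16")
#     input_data = {
#         name: np.random.uniform(-1.0, 1.0, shape).astype(model["input_dtypes"][name])
#         for name, shape in model["input_shapes"].items()
#     }
#     mean_seconds = collage(model, input_data, tune_log="adreno_v0.01.log")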


def just_clml(model, input_data, tune_log=""):
    """Run partition_for_clml, complete the compilation with TVM, and profile the result."""
    logging.info(f"just_clml | {model['name']}")
    logging.info("-------------- BEGIN ORIGINAL --------------")
    logging.info(model["mod"])
    logging.info("-------------- END ORIGINAL ----------------")
    with autotvm.apply_history_best(tune_log):
        with tvm.transform.PassContext(opt_level=3):
            logging.info("Partitioning for CLML...")
            mod = tvm.relay.op.contrib.clml.partition_for_clml(model["mod"], model["params"])
            partitioned_model = model.copy()
            partitioned_model["mod"] = mod
            logging.info("-------------- BEGIN PARTITIONED --------------")
            logging.info(partitioned_model["mod"])
            logging.info("-------------- END PARTITIONED ----------------")
            targets = []
            targets.append(OPENCL)
            targets.append(tvm.target.Target("clml", HOST))
            # Compile with both the OpenCL and CLML targets.
            return compile_and_run("just_clml", partitioned_model, targets, input_data)


def just_tvm(model, input_data, tune_log=""):
    """Compile and profile using vanilla TVM."""
    logging.info(f"just_tvm | {model['name']}")
    logging.info("-------------- BEGIN ORIGINAL --------------")
    logging.info(model["mod"])
    logging.info("-------------- END ORIGINAL ----------------")
    with autotvm.apply_history_best(tune_log):
        with tvm.transform.PassContext(opt_level=3):
            return compile_and_run("just_tvm", model, OPENCL, input_data)


def get_model(model_name, dtype):
    # Most testing workloads expect 224x224 inputs; inception_v3 expects 299x299.
    input_shape = [1, 3, 224, 224]
    if "mobilenet" in model_name:
        mod, params = testing.mobilenet.get_workload(batch_size=1, dtype=dtype)
    elif "resnet" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.resnet.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype)
    elif model_name == "inception_v3":
        input_shape = [1, 3, 299, 299]
        mod, params = testing.inception_v3.get_workload(batch_size=1, dtype=dtype)
    elif "vgg" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.vgg.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype)
    elif "densenet" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.densenet.get_workload(
            densenet_size=n_layer, batch_size=1, dtype=dtype
        )
    elif "squeezenet" in model_name:
        version = model_name.split("_v")[1]
        mod, params = testing.squeezenet.get_workload(batch_size=1, version=version, dtype=dtype)

    # Re-initialize the parameters with Xavier so the benchmark does not run on
    # all-zero weights.
    initializer = tvm.relay.testing.init.Xavier()
    for param_name in list(params.keys()):
        filter_data = np.zeros(params[param_name].shape).astype(params[param_name].dtype)
        if len(filter_data.shape) > 1:
            initializer("weight", filter_data)
        else:
            initializer("bias", filter_data)
        params[param_name] = tvm.nd.array(filter_data)

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)
        mod = tvm.relay.transform.FoldConstant()(mod)
    return {
        "name": model_name,
        "input_shapes": {"data": input_shape},
        "input_dtypes": {"data": dtype},
        "mod": mod,
        "params": params,
        "main_dtype": dtype,
    }


########### Runners ###########
def evaluate_network(model_name, dtype):
    print("Evaluating network " + model_name + " " + dtype)
    np.random.seed(0)
    model = get_model(model_name, dtype)
    tune_log = "adreno_v0.01.log"
    if args.tune:
        # Auto-tuning
        tune_log = "adreno-" + model_name + "-" + dtype + ".log"
        tuning_options = {
            "log_filename": tune_log,
            "early_stopping": None,
            "measure_option": autotvm.measure_option(
                builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15),
                runner=autotvm.RPCRunner(
                    args.rpc_key,
                    host=args.host,
                    port=args.port,
                    number=3,
                    timeout=600,
                ),
            ),
        }
        tasks = autotvm.task.extract_from_program(
            model["mod"], target=OPENCL, target_host=HOST, params=model["params"]
        )
        tune_tasks(tasks, **tuning_options)

    print_progress("%-20s building..." % model_name)
    input_data = {}
    for name, shape in model["input_shapes"].items():
        input_data[name] = np.random.uniform(-1.0, 1.0, shape).astype(model["input_dtypes"][name])
    clml_time = just_clml(model, input_data, tune_log)
    tvm_time = just_tvm(model, input_data, tune_log)

    # Run Collage over the TVM and CLML compiler targets.
    collage_time = collage(model, input_data, tune_log)
    return (tvm_time, clml_time, collage_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--network",
        type=str,
        choices=[
            "resnet-18",
            "resnet-34",
            "resnet-50",
            "vgg-16",
            "vgg-19",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ],
        help="The name of the neural network",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=9190)
    parser.add_argument("--rpc-key", type=str, default="android")
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16"],
        help="The data type of the neural network",
    )
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so expose --tune as a proper boolean flag instead.
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    if args.network is None:
        networks = [
            "resnet-18",
            "resnet-34",
            "resnet-50",
            # "vgg-16",
            # "vgg-19",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ]
    else:
        networks = [args.network]

    if args.dtype is None:
        dtypes = ["float32", "float16"]
    else:
        dtypes = [args.dtype]

    results = {}
    net_results = []
    for network in networks:
        for dtype in dtypes:
            ftime = evaluate_network(network, dtype)
            results[network + "-" + dtype] = ftime
            # net_results.append([network + "-" + dtype] + list(ftime))
            # np.savetxt("results.txt", np.array(net_results), fmt="%s")

    print("----------------------------------------------------------------------")
    print(
        "%-30s %-20s %-20s %-20s"
        % ("Network Name", "TVM OpenCL Time", "CLML Time", "Collage - TVM/CLML Time")
    )
    print("----------------------------------------------------------------------")
    for key, val in results.items():
        # The profiled times are in seconds; convert to milliseconds for display.
        print(
            "%-30s %-20s %-20s %-20s"
            % (
                key,
                "%.2f ms" % (val[0] * 1000),
                "%.2f ms" % (val[1] * 1000),
                "%.2f ms" % (val[2] * 1000),
            )
        )
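
To reproduce a single row of the table, the script is driven entirely by its argparse options. A sketch of one invocation, assuming the file is saved as bench_collage.py (a hypothetical name; the diff header above does not show the file path) and an RPC tracker has an Adreno device registered under the "android" key:

    python3 bench_collage.py --network resnet-18 --dtype float16 \
        --host 127.0.0.1 --port 9190 --rpc-key android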
