-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
90 lines (71 loc) · 3.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding:utf-8 -*-
# Author: liyanpeng
# Email: youran.xia@foxmail.com
# Datetime: 2022/9/19 20:34
# Filename: utils.py
import os
import glob
import warnings
import numpy as np
from visu_tvm import VisuGraph, VisuGraphFuseOps, VisuGraphRUF, VisuGraphMC
# Pass names that visu_relay_ir_single renders with VisuGraphFuseOps.
FO_LIST = [
    'FuseOps',
    'AllPass',
]

# Pass names that visu_relay_ir_single renders with VisuGraphRUF.
RUF_LIST = [
    'RemoveUnusedFunctions',
    'ToBasicBlockNormalForm',
    'EliminateCommonSubexpr',
    'FoldConstant',
    'SimplifyInference',
    'CombineParallelConv2D',
    'CombineParallelDense',
    'CombineParallelBatchMatmul',
    'FoldScaleAxis',
    'SimplifyExpr',
    'CanonicalizeCast',
    'CanonicalizeOps',
    'FlattenAtrousConv',
    'FastMath',
    'ConvertLayout',
]

# Pass names that visu_relay_ir_single renders with VisuGraphMC.
MC_LIST = [
    'MergeComposite',
]
def relay_ir2txt(context, file_name='example', is_ap=False, save_path='relay_ir'):
    """Dump the textual form of a Relay IR object to a ``.txt`` file.

    The file is written to ``<save_path>/<file_name>_ap.txt`` when *is_ap* is
    true and to ``<save_path>/<file_name>_bp.txt`` otherwise. These are the
    after-pass / before-pass suffixes that ``run_all_examples`` pairs up.

    Args:
        context: Object to dump; it is converted with ``str()`` before writing.
        file_name: Base name of the output file, without suffix.
        is_ap: Select the ``_ap`` suffix instead of ``_bp``.
        save_path: Output directory (default ``'relay_ir'``). It is created,
            including any missing parents, if it does not exist yet.
    """
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # handles nested directories. An empty save_path means "current dir".
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    suffix = '_ap.txt' if is_ap else '_bp.txt'
    out_file = os.path.join(save_path, file_name + suffix)
    with open(out_file, 'w', encoding='utf-8') as f:
        f.write(str(context))
def visu_relay_ir(bp_file, ap_file, save_name, with_info=False):
    """Visualize a Relay IR dump before and after a pass.

    The before-pass dump is always rendered with the generic ``VisuGraph``.
    For the after-pass dump, the visualizer is chosen from a pass tag in the
    file name:

    - ``_fo_`` selects ``VisuGraphFuseOps``.
    - A RUF-family tag selects ``VisuGraphRUF``.
    - ``_mc_`` selects ``VisuGraphMC``.

    Tags are checked in that order. An unknown tag emits a warning and falls
    back to ``VisuGraphRUF``.

    Args:
        bp_file: Path to the before-pass Relay IR text file.
        ap_file: Path to the after-pass Relay IR text file.
        save_name: Output name passed to every visualizer.
        with_info: Forwarded unchanged to the visualizers.
    """
    # Pass tags (as embedded in dump file names) rendered with VisuGraphRUF.
    ruf_tags = ('_ruf_', '_fc_', '_ecs_', '_si_', '_fm_', '_se_', '_fac_',
                '_cc_', '_cl_', '_fsa_', '_cpc2d_', '_cpd_', '_cpbm_')

    g = VisuGraph(txt_file=bp_file, save_name=save_name, with_info=with_info)
    g.codegen()

    # Match tags against the file name only, so a directory whose name happens
    # to contain a tag cannot select the wrong visualizer.
    ap_name = os.path.basename(ap_file)
    if '_fo_' in ap_name:
        visu_cls = VisuGraphFuseOps
    elif any(tag in ap_name for tag in ruf_tags):
        visu_cls = VisuGraphRUF
    elif '_mc_' in ap_name:
        visu_cls = VisuGraphMC
    else:
        warnings.warn("not support the pass to visu now! ==> {}".format(ap_file))
        # TODO: Without suitable test cases, visualizing the Relay IR produced
        # by some passes may fail. Some passes may also have no effect on the
        # network being optimized (so far only tested on resnet18), in which
        # case the before/after visualizations are identical.
        visu_cls = VisuGraphRUF
    g = visu_cls(txt_file=ap_file, save_name=save_name, with_info=with_info)
    g.codegen()
def visu_relay_ir_single(ir_file, save_name, pass_name='', with_info=False):
    """Render a single Relay IR text dump, picking the visualizer by pass name.

    Without a *pass_name*, the generic ``VisuGraph`` is used. Otherwise the
    name is looked up, in this order, in:

    - ``FO_LIST`` (``VisuGraphFuseOps``)
    - ``RUF_LIST`` (``VisuGraphRUF``)
    - ``MC_LIST`` (``VisuGraphMC``)

    An unrecognized pass name only emits a warning; nothing is rendered.
    """
    if not pass_name:
        chosen = VisuGraph
    else:
        lookup = (
            (FO_LIST, VisuGraphFuseOps),
            (RUF_LIST, VisuGraphRUF),
            (MC_LIST, VisuGraphMC),
        )
        # The first list containing pass_name wins; None if none does.
        chosen = next((cls for names, cls in lookup if pass_name in names), None)
        if chosen is None:
            warnings.warn("not support the pass to visu now! ==> [pass: {}, file: {}]".format(pass_name, ir_file))
            return
    graph = chosen(txt_file=ir_file, save_name=save_name, with_info=with_info)
    graph.codegen()
def _pair_paths(bp_file):
    """Return ``(ap_file, save_name)`` for a ``..._bp.txt`` dump path.

    Only the trailing ``_bp.txt`` suffix is rewritten, so ``_bp`` or ``.txt``
    appearing elsewhere in the path (a directory or model name) is preserved.

    Raises:
        ValueError: If *bp_file* does not end with ``_bp.txt``.
    """
    if not bp_file.lower().endswith('_bp.txt'):
        raise ValueError("expected a '*_bp.txt' file, got {!r}".format(bp_file))
    save_name = bp_file[:-len('.txt')]            # e.g. 'relay_ir/net_fo_bp'
    ap_file = save_name[:-len('_bp')] + '_ap.txt'  # e.g. 'relay_ir/net_fo_ap.txt'
    return ap_file, save_name


def run_all_examples(scan_dir='relay_ir', with_info=False):
    """Visualize every before/after-pass Relay IR pair found in *scan_dir*.

    Each ``*_bp.txt`` file is paired with the ``*_ap.txt`` file of the same
    base name and passed to ``visu_relay_ir``. Files are processed in sorted
    order so that repeated runs are deterministic; ``glob`` itself returns
    matches in arbitrary order.

    Args:
        scan_dir: Directory to scan (non-recursively) for ``*_bp.txt`` dumps.
        with_info: Forwarded to ``visu_relay_ir``.
    """
    for bp_file in sorted(glob.glob(os.path.join(scan_dir, '*_bp.txt'))):
        ap_file, save_name = _pair_paths(bp_file)
        print("Parsing {} and {}".format(bp_file, ap_file))
        visu_relay_ir(bp_file, ap_file, save_name, with_info)
def _get_positive_scale(size):
    """Return float32 samples drawn uniformly from ``[0.5, 1)`` with shape *size*.

    Draws from NumPy's global random state, so the result is reproducible
    only if ``np.random.seed`` has been set beforehand.
    """
    samples = np.random.uniform(low=0.5, high=1, size=size)
    return samples.astype("float32")
if __name__ == '__main__':
    # Visualize the default dump directory first, then the TVM test-case dumps.
    for example_dir in ('relay_ir', 'relay_ir/tvm_case'):
        run_all_examples(scan_dir=example_dir)