Commit 0e1cd57

Ziheng Jiang authored and piiswrong committed
MXNet Profiler (apache#3163)
* NNVM Refactor (apache#3194)
* Init nnvm change
* temp checkin
* Move TShape to NNVM
* Redirect Symbolic API to NNVM
* Add Op Prop Adapter
* Finish migrate in shape infer
* Pass all symbolic test
* temp commit
* enable aux data
* [EXEC] Basic version of exec for forward only
* [EXEC] Enable most optimizations, still wait grad and context
* fix legacy op with latest one
* Update NNVM NodeRef
* Adapt to newer interface
* ALl registry of backop is complete
* temp commit
* Hack finish backward pass
* [EXEC] One day pass
* [EXEC] Pass all operator unittest
* [EXEC] enable model parallel
* Fully pass all legacy tests
* Remove legacy symbolic code
* update news
* Make travis compile
* Fix python3
* Update viz module to new json format
* [NNVM] Imperative Invoke (apache#3208)
* [Engine] Deduplicate Variable Util
* [NNVM] NNVM Imperative Invoke
* [NNVM] Imperative improve speed
* fix
* fix
* [scala] link libnnvm.a (apache#3214)
* [PYTHON] Optional Cython Module for Symbols (apache#3242)
* [CYTHON] Checkin cython enhancement
* fix lint
* [DOC] Move common doc to base
* [EXEC] Support fcompute (apache#3249)
* [EXEC] Support fcompute
* Fix lint
* fix lint
* [OP] Add alias support (apache#3261)
* Fix path in setup.py (apache#3276)
* Fix path in setup.py
* revert the nnvm version
* [WIP] Element wise op refactor (apache#3245)
* [OPERATOR] Refactor Unary Ops
* [OPERATOR] Refactor Binary Scalar Ops
* Use alias
* update nnvm version (apache#3290)
* Fix breaking changes after pull master (apache#3291)
* [CYTHON] Cython module for NDArray (apache#3292)
* [NDARRAY] Cython module for ndarray
* More strict tests
* [NNVM] change of attr to set_attr (apache#3303)
* Update run_test.sh
* add nnvm cmake with windows (apache#3255)
* [WIP] binary broadcast wip (apache#3301)
* [WIP] binary broadcast wip [OPERATOR] Binary Broadcast ops fix lint lint fix max and min update submodule before removing reduce axis broad cast reduce ops
* update
* fix
* fix warning
* fix
* x (apache#3308)
* [IO] Python based ImageIter and Augumenter (apache#3227)
* [IO] Python based ImageIter and Augumenter
* fix
* fix
* fix
* [OPT] NNVM Optimizer (apache#3314)
* fix cpython in windows (apache#3309)
* Add Mathematical functions (apache#3317)
* fix image io
* add hypot degrees radians cosh sinh tanh arcsinh arccosh arctanh (apache#3335)
* add recent examples, collect some missing tutorials (apache#3340)
* Improving docs & utilities for distributed training example. (apache#3341)
* add init dict
* disable SSE for arm hardware e.g. Raspberry Pi (apache#3346)
* Add channel_ to Shape2D calculation (apache#3181)
* Add channel_ to Shape2D calculation
* scalapkg, add example multitask (apache#3186)
* RNN cell demo with ptb LSTM language model (apache#3197)
* rnn-cell demo (push to server for testing)
* a running example with cuDNN RNN cell
* Bulk lint fix (apache#3211)
* [TENSOR] Add FlatTo1D for all elementwise ops (apache#3238)
* Fix little bug on context (apache#3202)
* add PennTreeBank Language Model using lstm model in R (apache#2659)
* Add function 'print_summary' and some revise (apache#3161)
* Add function 'print_summary' and some revise Add function 'print_summary' for print detail information of network, and format argument was add in 'plot_network'. You can use 'print_summary' like: """ net = get_symbol(1000) shape = {'softmax_label': (64, 12), 'data': (64, 3, 224, 224)} mx.viz.print_summary(net, shape=shape) """ If without shape, the number of arguments would be nonsense currently.
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Update visualization.py
* Added my CmakeLists.txt for caffe plugin, etc.
* Revert "fix travis scala test config" (apache#3246) This reverts parts of commit 3e15f62. Reenables testing the Julia bindings
* [Scala] Code generation for Symbol (apache#3217) [scala] auto-generate Symbol functions
* fix spelling errors (apache#3258) Also align grammar and punctuation in short descriptions of features
* fix typo in run_test.sh (apache#3260)
* Copy slice along arbitrary axis (apache#3259)
* rnn-cell demo (push to server for testing)
* a running example with cuDNN RNN cell
* add copyslice along arbitrary axis for NDArray
* copy_slice_to as an ndarray operator
* Python interface to the _copy_slice_to operator
* fix lint error
* Enable concatenation for dim-1 vectors (apache#3264)
* fix PReLU backward computing (apache#3277)
* Add `reverse` option in Reshape (apache#3280)
* add scala example, end2end neural-style (apache#3267) add scala example, end2end neural-style
* Improve multi-GPU performance (apache#3241)
* update kvstore
* update model.py
* bandwith tool
* update readme
* tiny
* fix lint
* fix batch size of dist_device_sync
* fix
* fix perf problem of kvstore when only using a single device
* roll back to previous strategy how to choose update_on_kvsotre
* add an optionl MXNET_ENABLE_GPU_P2P to control whether or not use p2p
* update dmlccore (apache#3293)
* Fix newer version of gtest and cpptest (apache#3294)
* when set use_global_stats then do not use cudnn (apache#3289)
* when set use_global_stats then do not use cudnn
* fix batch norm with use_global_stats
* Fix req+reserve_space in cudnn_rnn (apache#3274) Fix req Fix reserve_space Allocate reserve_space using Storage
* add cudnn off option in Convolution (apache#3270)
* add support for building on power (apache#3302)
* add recent examples, collect some missing tutorials (apache#3340)
* CMake for caffe plugin
* Fix metric & im2rec.py
* [Scala] Nnvm ops for NDArray & Symbol (apache#3361)
* [scala] nnvm op support
* [scala] remove unused codes
* fix scala native code style
* [R] Fix the R interface (apache#3334)
* [R] Fix the R interface. remove man
* Fix BN legacy issue
* Locate compiled library on Windows (apache#3369)
* Fix metric & im2rec.py (apache#3375) image io fix
* Update legacy op FBackwardInGradIndex (apache#3376)
* Update legacy op FBackwardInGradIndex
* fix test
* Fix for LRN Layer (apache#3366)
* fixed cpu forward bug
* added out_data[lrn_enum::kOut] as backward req.
* removed lint
* removed duplicate out_data[lrn_enum::kTmpNorm],
* removed inplace option
* add backward index
* include some special functions (apache#3337) - gamma - gammaln - log1p - expm1
* fix kv build (apache#3385)
* initial profiler branch based on dmlc/mxnet:nnvm
* [profiler] add profiler & modify engine API
* [profiler] add USE_PROFILER compile flag & modify code for changed engine api
* [profiler] add c_api interface & modify graph_executor
* [profiler] add python api
* [profiler] typo & lint error
* [profiler] reduce overhead & add PROFIELR_MESSAGE_FUNCNAME macro
* [profiler] remove profiling argument from PushSync/PushAsync
* [profiler] refactor profiler.h/.cc
* [profiler] improve readability
* [profiler] typo && add TODO comment
* [profiler] fix ndarray op name & add WaitForVar back
* [profiler] add example/profiler/profiler_ndarray.py
* [profiler] fix memleak by using op->name
* [profiler] fix lint
* [profiler] fix lint
1 parent a02a1e7 commit 0e1cd57

31 files changed: +1261 −75 lines

Makefile (+5 lines)

@@ -45,6 +45,11 @@ else
NVCCFLAGS = -std=c++11 -Xcompiler -D_FORCE_INLINES -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
endif

+# CFLAGS for profiler
+ifeq ($(USE_PROFILER), 1)
+	CFLAGS += -DMXNET_USE_PROFILER=1
+endif
+
ifndef LINT_LANG
	LINT_LANG="all"
endif

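The profiler is compiled out by default; the hunk above only defines MXNET_USE_PROFILER=1 when make is invoked with USE_PROFILER=1. Once the library is rebuilt that way, the Python API this commit adds is driven by two calls, profiler_set_config and profiler_set_state, which both example scripts below use. A minimal sketch of that flow (the output filename and matrix size here are placeholders, not part of the commit):

import mxnet as mx

# Configure the output file before any profiled work runs; 'symbolic' is the
# mode both bundled examples use.
mx.profiler.profiler_set_config(mode='symbolic', filename='profile_minimal.json')

A = mx.symbol.Variable('A')
B = mx.symbol.Variable('B')
C = mx.symbol.dot(A, B)
exe = C.simple_bind(mx.cpu(), 'write', A=(1024, 1024), B=(1024, 1024))
exe.arg_dict['A'][:] = 1.0
exe.arg_dict['B'][:] = 1.0

mx.profiler.profiler_set_state('run')    # start recording engine activity
exe.forward()
exe.outputs[0].wait_to_read()            # block so the dot executes inside the recorded window
mx.profiler.profiler_set_state('stop')   # stop recording; the trace ends up in the configured file
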
example/profiler/profiler_executor.py (new file, +142 lines)

import mxnet as mx
import argparse
import os, sys
import time
import numpy as np
from mxnet import profiler


def parse_args():
    parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.')
    parser.add_argument('--profile_filename', type=str, default='profile_executor_5iter.json')
    parser.add_argument('--iter_num', type=int, default=5)
    parser.add_argument('--fc1', type=int, default=128)
    parser.add_argument('--fc2', type=int, default=128)
    parser.add_argument('--fc3', type=int, default=128)
    parser.add_argument('--fc4', type=int, default=128)
    return parser.parse_args()


def _download(data_dir):
    if not os.path.isdir(data_dir):
        os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists('train-images-idx3-ubyte')) or \
       (not os.path.exists('train-labels-idx1-ubyte')) or \
       (not os.path.exists('t10k-images-idx3-ubyte')) or \
       (not os.path.exists('t10k-labels-idx1-ubyte')):
        os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip")
        os.system("unzip -u mnist.zip; rm mnist.zip")
    os.chdir("..")


def get_data(data_shape):
    data_dir = "mnist/"
    batch_size = 128
    if '://' not in data_dir:
        _download(data_dir)

    train = mx.io.MNISTIter(
        image = data_dir + "train-images-idx3-ubyte",
        label = data_dir + "train-labels-idx1-ubyte",
        input_shape = data_shape,
        batch_size = batch_size,
        shuffle = True,
    )

    val = mx.io.MNISTIter(
        image = data_dir + "t10k-images-idx3-ubyte",
        label = data_dir + "t10k-labels-idx1-ubyte",
        input_shape = data_shape,
        batch_size = batch_size,
    )

    return (train, val)

def get_symbol():
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=args.fc1)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=args.fc2)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu')
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=args.fc3)
    act3 = mx.symbol.Activation(data=fc3, name='relu3', act_type='relu')
    fc4 = mx.symbol.FullyConnected(data=act3, name='fc4', num_hidden=args.fc4)
    act4 = mx.symbol.Activation(data=fc4, name='relu4', act_type='relu')
    fc5 = mx.symbol.FullyConnected(data=act4, name='fc5', num_hidden=10)
    net = mx.symbol.SoftmaxOutput(data=fc5, name='softmax')
    return net, [('data', (128, 1, 28, 28))], [('softmax_label', (128, ))]

def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=True, use_memonger=False):
    if use_memonger:
        sym = search_plan(sym, data=data_shapes)
    mod = mx.mod.Module(symbol=sym,
                        data_names=[name for name, _ in provide_data],
                        label_names=[name for name, _ in provide_label],
                        context=ctx)
    if batch_size is not None:
        provide_data = [(name, (batch_size,) + shape[1:]) for name, shape in provide_data]
        provide_label = [(name, (batch_size,) + shape[1:]) for name, shape in provide_label]
    if is_train:
        mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=True, inputs_need_grad=False)
    else:
        mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False)

    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
    mod.init_optimizer(optimizer='ccsgd',
                       optimizer_params={
                           'learning_rate': 0.0001,
                           'momentum': 0.0,
                           'wd': 0.0
                       })
    return mod


def benchmark(mod, dry_run=10, iterations=10):
    if len(mod._context) == 1:
        ctx = mod._context[0]
    else:
        ctx = mx.cpu()
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
    label = [mx.nd.array(np.random.randint(1, 100, size=shape), ctx=ctx) for _, shape in mod.label_shapes]
    batch = mx.io.DataBatch(data, label)

    # dry run
    for i in range(dry_run):
        mod.forward(batch, is_train=True)
        mod.backward()
        for output in mod.get_outputs(merge_multi_context=False)[0]:
            output.wait_to_read()
        mod.update()

    t0 = time.clock()

    profiler.profiler_set_state('run')
    # real run
    for i in range(iterations):
        mod.forward(batch, is_train=True)
        mod.backward()
        mod.update()
        for output in mod.get_outputs(merge_multi_context=False)[0]:
            output.wait_to_read()
    profiler.profiler_set_state('stop')

    t1 = time.clock()
    return (t1 - t0)*1000.0 / iterations


def executor(num_iteration):
    sym, provide_data, provide_label = get_symbol()
    ctx = [mx.gpu(0)]
    mod = get_module(ctx, sym, provide_data, provide_label, batch_size=128)
    return benchmark(mod, iterations=args.iter_num)


args = parse_args()

if __name__ == '__main__':
    mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename)
    print('profile file save to {0}'.format(args.profile_filename))
    print('executor num_iteration: {0}'.format(args.iter_num))
    executor_time = executor(args.iter_num)
    print("executor {0} ms / iteration".format(executor_time))

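On exit the script leaves profile_executor_5iter.json behind, a JSON trace that can be inspected directly or, if it follows the Chrome tracing layout, loaded into a viewer such as chrome://tracing. The helper below is purely illustrative and not part of the commit: it assumes a top-level traceEvents list whose complete events ('ph' == 'X') carry microsecond durations, and the keys would need adjusting if the emitted format differs.

import json
from collections import defaultdict

def summarize(path):
    """Hypothetical helper: total time per operator name from a profiler dump."""
    with open(path) as f:
        trace = json.load(f)
    totals = defaultdict(float)
    for event in trace.get('traceEvents', []):
        # Complete events carry their duration in microseconds under 'dur'.
        if event.get('ph') == 'X':
            totals[event.get('name', 'unknown')] += event.get('dur', 0)
    for name, us in sorted(totals.items(), key=lambda kv: -kv[1]):
        print('{0:<40s} {1:>10.3f} ms'.format(name, us / 1000.0))

if __name__ == '__main__':
    summarize('profile_executor_5iter.json')
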
example/profiler/profiler_matmul.py (new file, +49 lines)

import mxnet as mx
import argparse
import os, sys
import time
import numpy as np

def parse_args():
    parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.')
    parser.add_argument('--profile_filename', type=str, default='profile_matmul_20iter.json')
    parser.add_argument('--iter_num', type=int, default=100)
    parser.add_argument('--begin_profiling_iter', type=int, default=50)
    parser.add_argument('--end_profiling_iter', type=int, default=70)
    return parser.parse_args()

args = parse_args()

if __name__ == '__main__':
    mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename)
    print('profile file save to {0}'.format(args.profile_filename))

    A = mx.sym.Variable('A')
    B = mx.sym.Variable('B')
    C = mx.symbol.dot(A, B)

    executor = C.simple_bind(mx.gpu(1), 'write', A=(4096, 4096), B=(4096, 4096))

    a = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
    b = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))

    a.copyto(executor.arg_dict['A'])
    b.copyto(executor.arg_dict['B'])

    flag = False
    print "execution begin"
    for i in range(args.iter_num):
        if i == args.begin_profiling_iter:
            t0 = time.clock()
            mx.profiler.profiler_set_state('run')
        if i == args.end_profiling_iter:
            t1 = time.clock()
            mx.profiler.profiler_set_state('stop')
        executor.forward()
        c = executor.outputs[0]
        c.wait_to_read()
    print "execution end"
    duration = t1 - t0
    print('duration: {0}s'.format(duration))
    print(' {0}ms/operator'.format(duration*1000/args.iter_num))

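Two practical notes on this script: it binds the executor to mx.gpu(1), so it assumes a machine with at least two GPUs, and the timed window covers end_profiling_iter - begin_profiling_iter iterations (20 by default) while the final print divides that window by the full --iter_num (100), so the reported ms/operator spreads the window over all iterations rather than only the profiled ones. The warm-up-then-window pattern itself generalizes; a hedged sketch with placeholder sizes that divides by the window length instead:

import time
import mxnet as mx

mx.profiler.profiler_set_config(mode='symbolic', filename='profile_window.json')

A = mx.symbol.Variable('A')
B = mx.symbol.Variable('B')
C = mx.symbol.dot(A, B)
exe = C.simple_bind(mx.cpu(), 'write', A=(1024, 1024), B=(1024, 1024))
exe.arg_dict['A'][:] = 1.0
exe.arg_dict['B'][:] = 1.0

iter_num, begin, end = 100, 50, 70
for i in range(iter_num):
    if i == begin:
        mx.profiler.profiler_set_state('run')    # open the profiling window
        t0 = time.time()
    if i == end:
        t1 = time.time()
        mx.profiler.profiler_set_state('stop')   # close it; later iterations are not traced
    exe.forward()
    exe.outputs[0].wait_to_read()                # synchronize so each dot is fully timed

print('{0:.3f} ms per dot over {1} profiled iterations'.format(
    (t1 - t0) * 1000.0 / (end - begin), end - begin))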