This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNet Profiler (#3163)
* NNVM Refactor (#3194)

* Init nnvm change

* temp checkin

* Move TShape to NNVM

* Redirect Symbolic API to NNVM

* Add Op Prop Adapter

* Finish migrate in shape infer

* Pass all symbolic test

* temp commit

* enable aux data

* [EXEC] Basic version of exec for forward only

* [EXEC] Enable most optimizations, still wait grad and context

* fix legacy op with latest one

* Update NNVM NodeRef

* Adapt to newer interface

* All registry of backop is complete

* temp commit

* Hack finish backward pass

* [EXEC] One day pass

* [EXEC] Pass all operator unittest

* [EXEC] enable model parallel

* Fully pass all legacy tests

* Remove legacy symbolic code

* update news

* Make travis compile

* Fix python3

* Update viz module to new json format

* [NNVM] Imperative Invoke (#3208)

* [Engine] Deduplicate Variable Util

* [NNVM] NNVM Imperative Invoke

* [NNVM] Imperative improve speed

* fix

* fix

* [scala] link libnnvm.a (#3214)

* [PYTHON] Optional Cython Module for Symbols (#3242)

* [CYTHON] Checkin cython enhancement

* fix lint

* [DOC] Move common doc to base

* [EXEC] Support fcompute (#3249)

* [EXEC] Support fcompute

* Fix lint

* fix lint

* [OP] Add alias support (#3261)

* Fix path in setup.py (#3276)

* Fix path in setup.py

* revert the nnvm version

* [WIP] Element wise op refactor (#3245)

* [OPERATOR] Refactor Unary Ops

* [OPERATOR] Refactor Binary Scalar Ops

* Use alias

* update nnvm version (#3290)

* Fix breaking changes after pull master (#3291)

* [CYTHON] Cython module for NDArray (#3292)

* [NDARRAY] Cython module for ndarray

* More strict tests

* [NNVM] change of attr to set_attr (#3303)

* Update run_test.sh

* add nnvm cmake with windows (#3255)

* [WIP] binary broadcast wip (#3301)

* [WIP] binary broadcast wip

[OPERATOR] Binary Broadcast ops

fix lint

lint

fix

max and min

update submodule

before removing reduce axis

broad cast reduce ops

* update

* fix

* fix warning

* fix

* x (#3308)

* [IO] Python based ImageIter and Augmenter (#3227)

* [IO] Python based ImageIter and Augmenter

* fix

* fix

* fix

* [OPT] NNVM Optimizer (#3314)

* fix cpython in windows (#3309)

* Add Mathematical functions (#3317)

* fix image io

* add hypot degrees radians cosh sinh tanh arcsinh arccosh arctanh (#3335)

* add recent examples, collect some missing tutorials (#3340)

* Improving docs & utilities for distributed training example. (#3341)

* add init dict

* disable SSE for arm hardware e.g. Raspberry Pi (#3346)

* Add channel_ to Shape2D calculation (#3181)

* Add channel_ to Shape2D calculation

* scalapkg, add example multitask (#3186)

* RNN cell demo with ptb LSTM language model (#3197)

* rnn-cell demo (push to server for testing)

* a running example with cuDNN RNN cell

* Bulk lint fix (#3211)

* [TENSOR] Add FlatTo1D for all elementwise ops (#3238)

* Fix little bug on context (#3202)

* add PennTreeBank Language Model using lstm model in R (#2659)

* Add function 'print_summary' and some revise (#3161)

* Add function 'print_summary' and some revise

Add function 'print_summary' to print detailed information about a network, and add a format argument to 'plot_network'.
You can use 'print_summary' like this:
"""
net = get_symbol(1000)
shape = {'softmax_label': (64, 12), 'data': (64, 3, 224, 224)}
mx.viz.print_summary(net, shape=shape)
"""
Without shape, the reported argument counts are currently meaningless.

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Update visualization.py

* Added my CmakeLists.txt for caffe plugin, etc.

* Revert "fix travis scala test config" (#3246)

This reverts parts of commit 3e15f62.
Reenables testing the Julia bindings

* [Scala] Code generation for Symbol (#3217)

[scala] auto-generate Symbol functions

* fix spelling errors (#3258)

Also align grammar and punctuation in short descriptions of features

* fix typo in run_test.sh (#3260)

* Copy slice along arbitrary axis (#3259)

* rnn-cell demo (push to server for testing)

* a running example with cuDNN RNN cell

* add copyslice along arbitrary axis for NDArray

* copy_slice_to as an ndarray operator

* Python interface to the _copy_slice_to operator

* fix lint error

* Enable concatenation for dim-1 vectors (#3264)

* fix PReLU backward computing (#3277)

* Add `reverse` option in Reshape (#3280)

* add scala example, end2end neural-style (#3267)

add scala example, end2end neural-style

* Improve multi-GPU performance (#3241)

* update kvstore

* update model.py

* bandwidth tool

* update readme

* tiny

* fix lint

* fix batch size of dist_device_sync

* fix

* fix perf problem of kvstore when only using a single device

* roll back to the previous strategy for choosing update_on_kvstore

* add an option MXNET_ENABLE_GPU_P2P to control whether or not to use p2p

* update dmlccore (#3293)

* Fix newer version of gtest and cpptest (#3294)

* when set use_global_stats then do not use cudnn (#3289)

* when set use_global_stats then do not use cudnn

* fix batch norm with use_global_stats

* Fix req+reserve_space in cudnn_rnn (#3274)

Fix req

Fix reserve_space

Allocate reserve_space using Storage

* add cudnn off option in Convolution (#3270)

* add support for building on power (#3302)

* add recent examples, collect some missing tutorials (#3340)

* CMake for caffe plugin

* Fix metric & im2rec.py

* [Scala] Nnvm ops for NDArray & Symbol (#3361)

* [scala] nnvm op support

* [scala] remove unused codes

* fix scala native code style

* [R] Fix the R interface (#3334)

* [R] Fix the R interface. remove man

* Fix BN legacy issue

* Locate compiled library on Windows (#3369)

* Fix metric & im2rec.py (#3375)

image io fix

* Update legacy op FBackwardInGradIndex (#3376)

* Update legacy op FBackwardInGradIndex

* fix test

* Fix for LRN Layer (#3366)

* fixed cpu forward bug

* added out_data[lrn_enum::kOut] as backward req.

* removed lint

* removed duplicate out_data[lrn_enum::kTmpNorm],

* removed inplace option

* add backward index

* include some special functions (#3337)

- gamma
- gammaln
- log1p
- expm1

* fix kv build (#3385)

* initial profiler branch based on dmlc/mxnet:nnvm

* [profiler] add profiler & modify engine API

* [profiler] add USE_PROFILER compile flag & modify code for changed engine api

* [profiler] add c_api interface & modify graph_executor

* [profiler] add python api

* [profiler] typo & lint error

* [profiler] reduce overhead & add PROFILER_MESSAGE_FUNCNAME macro

* [profiler] remove profiling argument from PushSync/PushAsync

* [profiler] refactor profiler.h/.cc

* [profiler] improve readability

* [profiler] typo && add TODO comment

* [profiler] fix ndarray op name & add WaitForVar back

* [profiler] add example/profiler/profiler_ndarray.py

* [profiler] fix memleak by using op->name

* [profiler] fix lint

* [profiler] fix lint
Ziheng Jiang authored and piiswrong committed Nov 30, 2016
1 parent 09ff15b commit acdcfd1
Showing 31 changed files with 1,261 additions and 75 deletions.
5 changes: 5 additions & 0 deletions Makefile
@@ -45,6 +45,11 @@ else
NVCCFLAGS = -std=c++11 -Xcompiler -D_FORCE_INLINES -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
endif

# CFLAGS for profiler
ifeq ($(USE_PROFILER), 1)
CFLAGS += -DMXNET_USE_PROFILER=1
endif

ifndef LINT_LANG
LINT_LANG="all"
endif
142 changes: 142 additions & 0 deletions example/profiler/profiler_executor.py
@@ -0,0 +1,142 @@
import mxnet as mx
import argparse
import os, sys
import time
import numpy as np
from mxnet import profiler


def parse_args():
    parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.')
    parser.add_argument('--profile_filename', type=str, default='profile_executor_5iter.json')
    parser.add_argument('--iter_num', type=int, default=5)
    parser.add_argument('--fc1', type=int, default=128)
    parser.add_argument('--fc2', type=int, default=128)
    parser.add_argument('--fc3', type=int, default=128)
    parser.add_argument('--fc4', type=int, default=128)
    return parser.parse_args()


def _download(data_dir):
    if not os.path.isdir(data_dir):
        os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists('train-images-idx3-ubyte')) or \
       (not os.path.exists('train-labels-idx1-ubyte')) or \
       (not os.path.exists('t10k-images-idx3-ubyte')) or \
       (not os.path.exists('t10k-labels-idx1-ubyte')):
        os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip")
        os.system("unzip -u mnist.zip; rm mnist.zip")
    os.chdir("..")


def get_data(data_shape):
    data_dir = "mnist/"
    batch_size = 128
    if '://' not in data_dir:
        _download(data_dir)

    train = mx.io.MNISTIter(
        image = data_dir + "train-images-idx3-ubyte",
        label = data_dir + "train-labels-idx1-ubyte",
        input_shape = data_shape,
        batch_size = batch_size,
        shuffle = True,
        )

    val = mx.io.MNISTIter(
        image = data_dir + "t10k-images-idx3-ubyte",
        label = data_dir + "t10k-labels-idx1-ubyte",
        input_shape = data_shape,
        batch_size = batch_size,
        )

    return (train, val)

def get_symbol():
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=args.fc1)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')
    fc2 = mx.symbol.FullyConnected(data=act1 , name='fc2', num_hidden=args.fc2)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu')
    fc3 = mx.symbol.FullyConnected(data=act2 , name='fc3', num_hidden=args.fc3)
    act3 = mx.symbol.Activation(data=fc3, name='relu3', act_type='relu')
    fc4 = mx.symbol.FullyConnected(data=act3 , name='fc4', num_hidden=args.fc4)
    act4 = mx.symbol.Activation(data=fc4, name='relu4', act_type='relu')
    fc5 = mx.symbol.FullyConnected(data=act4 , name='fc5', num_hidden=10)
    net = mx.symbol.SoftmaxOutput(data=fc5 , name='softmax')
    return net, [('data', (128, 1, 28, 28))], [('softmax_label', (128, ))]

def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=True, use_memonger=False):
    if use_memonger:
        sym = search_plan(sym, data=data_shapes)
    mod = mx.mod.Module(symbol=sym,
                        data_names=[name for name, _ in provide_data],
                        label_names=[name for name, _ in provide_label],
                        context=ctx)
    if batch_size is not None:
        provide_data = [(name, (batch_size,) + shape[1:]) for name, shape in provide_data]
        provide_label = [(name, (batch_size,) + shape[1:]) for name, shape in provide_label]
    if is_train:
        mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=True, inputs_need_grad=False)
    else:
        mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False)

    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
    mod.init_optimizer(optimizer='ccsgd',
                       optimizer_params={
                           'learning_rate': 0.0001,
                           'momentum': 0.0,
                           'wd': 0.0
                       })
    return mod


def benchmark(mod, dry_run=10, iterations=10):
    if len(mod._context) == 1:
        ctx = mod._context[0]
    else:
        ctx = mx.cpu()
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
    label = [mx.nd.array(np.random.randint(1, 100, size=shape), ctx=ctx) for _, shape in mod.label_shapes]
    batch = mx.io.DataBatch(data, label)

    # dry run
    for i in range(dry_run):
        mod.forward(batch, is_train=True)
        mod.backward()
        for output in mod.get_outputs(merge_multi_context=False)[0]:
            output.wait_to_read()
        mod.update()

    t0 = time.clock()

    profiler.profiler_set_state('run')
    # real run
    for i in range(iterations):
        mod.forward(batch, is_train=True)
        mod.backward()
        mod.update()
        for output in mod.get_outputs(merge_multi_context=False)[0]:
            output.wait_to_read()
    profiler.profiler_set_state('stop')

    t1 = time.clock()
    return (t1 - t0)*1000.0 / iterations


def executor(num_iteration):
    sym, provide_data, provide_label = get_symbol()
    ctx = [mx.gpu(0)]
    mod = get_module(ctx, sym, provide_data, provide_label, batch_size=128)
    return benchmark(mod, iterations=args.iter_num)


args = parse_args()

if __name__ == '__main__':
    mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename)
    print('profile file save to {0}'.format(args.profile_filename))
    print('executor num_iteration: {0}'.format(args.iter_num))
    executor_time = executor(args.iter_num)
    print("executor {0} ms / iteration".format(executor_time))
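The benchmark in this file runs a warm-up loop first, then switches the profiler to 'run' only around the measured loop and back to 'stop' afterwards. A minimal sketch of that pattern follows, written so it runs without MXNet. The names `profiled` and `ms_per_iteration` are hypothetical helpers, not part of this commit; `set_state` stands in for a callable such as `mx.profiler.profiler_set_state`. The sketch also uses `time.perf_counter` in place of `time.clock`, since `time.clock` was removed in Python 3.8.

```python
import time
from contextlib import contextmanager


@contextmanager
def profiled(set_state):
    """Switch a profiler to 'run' for the duration of the block, then 'stop'.

    `set_state` is a hypothetical stand-in for a callable such as
    mx.profiler.profiler_set_state; it is injected so that this sketch
    runs without MXNet installed.
    """
    set_state('run')
    try:
        yield
    finally:
        # Always stop the profiler, even if the measured code raises.
        set_state('stop')


def ms_per_iteration(step, set_state, dry_run=10, iterations=10):
    """Return the mean wall-clock milliseconds per call of `step`.

    `step` is assumed to block until its work has finished (the example
    script achieves this by calling wait_to_read() on the outputs).
    """
    for _ in range(dry_run):
        step()  # warm-up: neither profiled nor timed
    t0 = time.perf_counter()
    with profiled(set_state):
        for _ in range(iterations):
            step()
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0 / iterations


if __name__ == '__main__':
    states = []  # records the state changes instead of calling a real profiler
    calls = []
    ms = ms_per_iteration(lambda: calls.append(1), states.append,
                          dry_run=2, iterations=3)
    print(states, len(calls))
```

Passing the state setter in as an argument keeps the timing logic testable on machines without a GPU or an MXNet build compiled with USE_PROFILER=1.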
49 changes: 49 additions & 0 deletions example/profiler/profiler_matmul.py
@@ -0,0 +1,49 @@
import mxnet as mx
import argparse
import os, sys
import time
import numpy as np

def parse_args():
    parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.')
    parser.add_argument('--profile_filename', type=str, default='profile_matmul_20iter.json')
    parser.add_argument('--iter_num', type=int, default=100)
    parser.add_argument('--begin_profiling_iter', type=int, default=50)
    parser.add_argument('--end_profiling_iter', type=int, default=70)
    return parser.parse_args()

args = parse_args()

if __name__ == '__main__':
    mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename)
    print('profile file save to {0}'.format(args.profile_filename))


    A = mx.sym.Variable('A')
    B = mx.sym.Variable('B')
    C = mx.symbol.dot(A, B)

    executor = C.simple_bind(mx.gpu(1), 'write', A=(4096, 4096), B=(4096, 4096))

    a = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
    b = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))

    a.copyto(executor.arg_dict['A'])
    b.copyto(executor.arg_dict['B'])

    flag = False
    print "execution begin"
    for i in range(args.iter_num):
        if i == args.begin_profiling_iter:
            t0 = time.clock()
            mx.profiler.profiler_set_state('run')
        if i == args.end_profiling_iter:
            t1 = time.clock()
            mx.profiler.profiler_set_state('stop')
        executor.forward()
        c = executor.outputs[0]
        c.wait_to_read()
    print "execution end"
    duration = t1 - t0
    print('duration: {0}s'.format(duration))
    print(' {0}ms/operator'.format(duration*1000/args.iter_num))
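The loop above times only the iterations from begin_profiling_iter up to, but not including, end_profiling_iter (20 iterations with the default arguments), yet the last line divides the measured duration by iter_num (100 by default). A small sketch of the window arithmetic, assuming the intent is an average over the profiled window; the helper names `profiled_iterations` and `ms_per_profiled_iteration` are hypothetical and not part of this commit:

```python
def profiled_iterations(iter_num, begin, end):
    """Return the loop indices that fall inside the window [begin, end)."""
    return [i for i in range(iter_num) if begin <= i < end]


def ms_per_profiled_iteration(duration_s, begin, end):
    """Average milliseconds per iteration inside the window [begin, end).

    Divides by the number of timed iterations rather than by the total
    number of loop iterations.
    """
    count = end - begin
    if count <= 0:
        raise ValueError('end must be greater than begin')
    return duration_s * 1000.0 / count


if __name__ == '__main__':
    # Defaults from the script above: 100 iterations, window 50..70.
    window = profiled_iterations(100, 50, 70)
    print(len(window))
    print(ms_per_profiled_iteration(2.0, 50, 70))
```

Note also that if end_profiling_iter is not smaller than iter_num, the original loop never reaches the 'stop' branch, so t1 is never assigned and the profiler is left running.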