Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[tools/bandwidth] fix measure.py, add test_measure.py (#4076)
Browse files Browse the repository at this point in the history
* [tools/bandwidth] fix measure.py, add test_measure.py

* remove import argparse

* fix typos
  • Loading branch information
mli authored Dec 3, 2016
1 parent feb8762 commit 54b3dc4
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 86 deletions.
29 changes: 18 additions & 11 deletions example/image-classification/test_score.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""
test pretrained models
"""
import argparse
import mxnet as mx
from common import find_mxnet, modelzoo
from common.util import download_file
from score import score
import subprocess

def get_gpus():
    """Return the ids of the detected NVIDIA GPUs as a comma-separated string.

    Runs ``nvidia-smi -L`` and counts the output lines that contain ``GPU``;
    the returned ids are simply ``0 .. count-1``.

    Returns
    -------
    str
        For example ``'0,1,2,3'`` when four GPU lines are listed, or ``''``
        when ``nvidia-smi`` cannot be executed or exits with a non-zero
        status.
    """
    try:
        # named `output` so the standard `re` module name is not shadowed
        output = subprocess.check_output(["nvidia-smi", "-L"],
                                         universal_newlines=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError: the binary could not be started;
        # CalledProcessError: it ran but reported failure.
        # Either way there are no usable GPUs to report.
        return ''
    num_gpus = sum(1 for line in output.splitlines() if 'GPU' in line)
    return ','.join(str(i) for i in range(num_gpus))

def download_data():
    """Download the ``val-5k-256.rec`` record file that the tests read from ``data/``."""
    url = 'http://data.mxnet.io/data/val-5k-256.rec'
    local_path = 'data/val-5k-256.rec'
    download_file(url, local_path)

def test_imagenet1k_resnet(args):
def test_imagenet1k_resnet(**kwargs):
models = ['imagenet1k-resnet-34',
'imagenet1k-resnet-50',
'imagenet1k-resnet-101',
Expand All @@ -19,28 +27,27 @@ def test_imagenet1k_resnet(args):
for (m, g) in zip(models, accs):
acc = mx.metric.create('acc')
(speed,) = score(model=m, data_val='data/val-5k-256.rec',
rgb_mean='0,0,0', metrics=acc, **vars(args))
rgb_mean='0,0,0', metrics=acc, **kwargs)
r = acc.get()[1]
print('testing %s, acc = %f, speed = %f img/sec' % (m, r, speed))
assert r > g and r < g + .1

def test_imagenet1k_inception_bn(args):
def test_imagenet1k_inception_bn(**kwargs):
acc = mx.metric.create('acc')
m = 'imagenet1k-inception-bn'
g = 0.72
(speed,) = score(model=m,
data_val='data/val-5k-256.rec',
rgb_mean='123.68,116.779,103.939', metrics=acc, **vars(args))
rgb_mean='123.68,116.779,103.939', metrics=acc, **kwargs)
r = acc.get()[1]
print('Tested %s acc = %f, speed = %f img/sec' % (m, r, speed))
assert r > g and r < g + .1

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test score.py')
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--batch-size', type=int, default=32)
args = parser.parse_args()
gpus = get_gpus()
assert gpus is not ''
batch_size = 32

download_data()
test_imagenet1k_resnet(args)
test_imagenet1k_inception_bn(args)
test_imagenet1k_resnet(gpus=gpus, batch_size=batch_size)
test_imagenet1k_inception_bn(gpus=gpus, batch_size=batch_size)
39 changes: 20 additions & 19 deletions tools/bandwidth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ measure.py --help` for more details.
- Use ResNet with 200 layers on GPUs 0 and 1

```bash
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1 --network resnet --depth 200
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=200, disp_batches=1, gpus='0,1', kv_store='device', network='resnet', num_batches=5, num_classes=1000, optimizer='None', test_results=1)
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1 --network resnet --num-layers 200
INFO:root:Namespace(disp_batches=1, gpus='0,1', image_shape='3,224,224', kv_store='device', network='resnet', num_batches=5, num_classes=1000, num_layers=200, optimizer='None', test_results=1)
INFO:root:num of arrays = 205, total size = 257.991328 MB
INFO:root:iter 1, 0.023242 sec, 11.100222 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.023106 sec, 11.165508 GB/sec per gpu, error 0.000000
Expand All @@ -47,8 +47,8 @@ because we do all-to-all communication.
- Use 8 GPUs, it saturates the single 16x link between GPU 0,1,2,3 and GPU 4,5,6,7.

```bash
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1,2,3,4,5,6,7 --network resnet --depth 200
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=200, disp_batches=1, gpus='0,1,2,3,4,5,6,7', kv_store='device', network='resnet', num_batches=5, num_classes=1000, optimizer='None', test_results=1)
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1,2,3,4,5,6,7 --network resnet --num-layers 200
INFO:root:Namespace(disp_batches=1, gpus='0,1,2,3,4,5,6,7', image_shape='3,224,224', kv_store='device', network='resnet', num_batches=5, num_classes=1000, num_layers=200, optimizer='None', test_results=1)
INFO:root:num of arrays = 205, total size = 257.991328 MB
INFO:root:iter 1, 0.102321 sec, 4.412429 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.100345 sec, 4.499330 GB/sec per gpu, error 0.000000
Expand All @@ -61,8 +61,8 @@ INFO:root:iter 5, 0.100774 sec, 4.480169 GB/sec per gpu, error 0.000000
between all GPUs and the CPU.

```bash
~/mxnet/tools/bandwidth $ python measure.py --kv-store local --gpus 0,1,2,3,4,5,6,7 --network resnet --depth 200
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=200, disp_batches=1, gpus='0,1,2,3,4,5,6,7', kv_store='local', network='resnet', num_batches=5, num_classes=1000, optimizer='None', test_results=1)
~/mxnet/tools/bandwidth $ python measure.py --kv-store local --gpus 0,1,2,3,4,5,6,7 --network resnet --num-layers 200
INFO:root:Namespace(disp_batches=1, gpus='0,1,2,3,4,5,6,7', image_shape='3,224,224', kv_store='local', network='resnet', num_batches=5, num_classes=1000, num_layers=200, optimizer='None', test_results=1)
INFO:root:num of arrays = 205, total size = 257.991328 MB
INFO:root:iter 1, 0.290164 sec, 1.555964 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.293963 sec, 1.535856 GB/sec per gpu, error 0.000000
Expand All @@ -71,17 +71,18 @@ INFO:root:iter 4, 0.290657 sec, 1.553325 GB/sec per gpu, error 0.000000
INFO:root:iter 5, 0.290799 sec, 1.552567 GB/sec per gpu, error 0.000000
```

- Finally we change to VGG and also run the `sgd` optimizer
- Finally we change to Inception-v3, which requires the input image size to be `3*299*299`, and also run the `sgd` optimizer

```bash
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1,2,3,4,5,6,7 --network vgg --optimizer sgd
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=152, disp_batches=1, gpus='0,1,2,3,4,5,6,7', kv_store='device', network='vgg', num_batches=5, num_classes=1000, optimizer='sgd', test_results=1)
INFO:root:num of arrays = 22, total size = 531.453344 MB
INFO:root:iter 1, 0.525208 sec, 1.770810 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.524052 sec, 1.774715 GB/sec per gpu, error 0.000000
INFO:root:iter 3, 0.524732 sec, 1.772416 GB/sec per gpu, error 0.000000
INFO:root:iter 4, 0.527117 sec, 1.764396 GB/sec per gpu, error 0.000000
INFO:root:iter 5, 0.520293 sec, 1.787538 GB/sec per gpu, error 0.000000
~/mxnet/tools/bandwidth $ python measure.py --kv-store device --gpus 0,1,2,3,4,5,6,7 --image-shape 3,299,299 --network inception-v3 --optimizer sgd
libdc1394 error: Failed to initialize libdc1394
INFO:root:Namespace(disp_batches=1, gpus='0,1,2,3,4,5,6,7', image_shape='3,299,299', kv_store='device', network='inception-v3', num_batches=5, num_classes=1000, num_layers=152, optimizer='sgd', test_results=1)
INFO:root:num of arrays = 96, total size = 95.200544 MB
INFO:root:iter 1, 0.086527 sec, 1.925424 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.057934 sec, 2.875700 GB/sec per gpu, error 0.000000
INFO:root:iter 3, 0.055442 sec, 3.004967 GB/sec per gpu, error 0.000000
INFO:root:iter 4, 0.055579 sec, 2.997555 GB/sec per gpu, error 0.000000
INFO:root:iter 5, 0.055107 sec, 3.023220 GB/sec per gpu, error 0.000000
```

### Multiple GPU machines
Expand All @@ -98,8 +99,8 @@ For more than one machines, we can replace `hosts` with the actual machine IPs
line by line. Then launch it by

```bash
~/mxnet/tools/bandwidth $ python ../launch.py -H hosts -n 1 python measure.py --kv-store dist_device_sync --gpus 0,1,2,3,4,5,6,7 --network resnet --depth 200
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=200, disp_batches=1, gpus='0,1,2,3,4,5,6,7', kv_store='dist_device_sync', network='resnet', num_batches=5, num_classes=1000, optimizer='None', test_results=1)
~/mxnet/tools/bandwidth $ python ../launch.py -H hosts -n 1 python measure.py --kv-store dist_device_sync --gpus 0,1,2,3,4,5,6,7 --network resnet --num-layers 200
INFO:root:Namespace(disp_batches=1, gpus='0,1,2,3,4,5,6,7', image_shape='3,224,224', kv_store='dist_device_sync', network='resnet', num_batches=5, num_classes=1000, num_layers=200, optimizer='None', test_results=1)
INFO:root:num of arrays = 205, total size = 257.991328 MB
INFO:root:iter 1, 0.295398 sec, 1.528395 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.303159 sec, 1.489267 GB/sec per gpu, error 0.000000
Expand All @@ -113,8 +114,8 @@ harms the performance. We can slightly improve the performance using more than
1 server nodes:

```bash
~/mxnet/tools/bandwidth $ python ../launch.py -H hosts -n 1 -s 4 python measure.py --kv-store dist_device_sync --gpus 0,1,2,3,4,5,6,7 --network resnet --depth 200
INFO:root:Namespace(batch_size=128, data_shape='128,3,224,224', depth=200, disp_batches=1, gpus='0,1,2,3,4,5,6,7', kv_store='dist_device_sync', network='resnet', num_batches=5, num_classes=1000, optimizer='None', test_results=1)
~/mxnet/tools/bandwidth $ python ../launch.py -H hosts -n 1 -s 4 python measure.py --kv-store dist_device_sync --gpus 0,1,2,3,4,5,6,7 --network resnet --num-layers 200
INFO:root:Namespace(disp_batches=1, gpus='0,1,2,3,4,5,6,7', image_shape='3,224,224', kv_store='dist_device_sync', network='resnet', num_batches=5, num_classes=1000, num_layers=200, optimizer='None', test_results=1)
INFO:root:num of arrays = 205, total size = 257.991328 MB
INFO:root:iter 1, 0.233309 sec, 1.935137 GB/sec per gpu, error 0.000000
INFO:root:iter 2, 0.253864 sec, 1.778453 GB/sec per gpu, error 0.000000
Expand Down
90 changes: 34 additions & 56 deletions tools/bandwidth/measure.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os, sys
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, os.path.join(curr_path, "../../python"))
sys.path.insert(0, os.path.join(curr_path, "../../example/image-classification"))
sys.path.insert(0, os.path.join(curr_path, "../../example/image-classification/symbol"))
import mxnet as mx
import logging
import argparse
import time
import numpy as np
from importlib import import_module
from collections import namedtuple

logger = logging.getLogger()
logger.setLevel(logging.INFO)
Expand All @@ -17,20 +19,18 @@ def parse_args():
help='the neural network to test')
parser.add_argument('--gpus', type=str, default='0,1',
help='the gpus to be used, e.g "0,1,2,3"')
parser.add_argument('--depth', type=int, default=152,
help='the depth of network, only valid for resnet')
parser.add_argument('--num-layers', type=int, default=152,
help='number of layers, can be used for resnet')
parser.add_argument('--kv-store', type=str, default='device',
help='the kvstore type')
parser.add_argument('--batch-size', type=int, default=128,
help='batch size. should not affect the results')
parser.add_argument('--num-batches', type=int, default=5,
help='number of batches to run')
parser.add_argument('--disp-batches', type=int, default=1,
help='show averaged results for every n batches')
parser.add_argument('--test-results', type=int, default=1,
help='if or not evalute the results correctness')
parser.add_argument('--data-shape', type=str, default='128,3,224,224',
help='input data shape')
parser.add_argument('--image-shape', type=str, default='3,224,224',
help='input images shape')
parser.add_argument('--num-classes', type=int, default=1000,
help='number of classes')
parser.add_argument('--optimizer', type=str, default='None',
Expand All @@ -39,33 +39,6 @@ def parse_args():
logging.info(args)
return args

def get_resnet(args):
    """Build a ResNet symbol from the tornadomeet/ResNet implementation.

    The ResNet repository is cloned next to this script on first use and its
    ``symbol_resnet`` module is imported from there.

    Parameters
    ----------
    args : object
        Must expose ``depth`` (one of 18, 34, 50, 101, 152, 200) and
        ``num_classes``; in this script it is the parsed argparse namespace.

    Returns
    -------
    The value returned by ``symbol_resnet.resnet`` for the chosen depth.

    Raises
    ------
    ValueError
        If ``args.depth`` is not one of the supported depths.
    """
    # number of residual units per stage for every supported depth
    units_per_depth = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    # validate first so an unsupported depth fails before any clone is attempted
    if args.depth not in units_per_depth:
        raise ValueError(
            "no experiments done on depth {}, you can do it yourself".format(args.depth))
    units = units_per_depth[args.depth]

    resnet_path = os.path.join(curr_path, "./ResNet")
    if not os.path.isdir(resnet_path):
        # Clone into resnet_path explicitly: a bare `git clone <url>` lands in
        # the current working directory, which need not be this script's
        # directory, and the import below would then not find the module.
        import subprocess
        subprocess.check_call(
            ["git", "clone", "https://github.com/tornadomeet/ResNet", resnet_path])
    sys.path.insert(0, resnet_path)
    from symbol_resnet import resnet

    # depths of 50 and above use the bottleneck variant with wider filters
    bottle_neck = args.depth >= 50
    if bottle_neck:
        filter_list = [64, 256, 512, 1024, 2048]
    else:
        filter_list = [64, 64, 128, 256, 512]
    return resnet(units=units, num_stage=4, filter_list=filter_list,
                  num_class=args.num_classes, data_type="imagenet",
                  bottle_neck=bottle_neck, bn_mom=.9, workspace=512)

def get_shapes(symbol, data_shape):
arg_name = symbol.list_arguments()
arg_shape, _, _ = symbol.infer_shape(data=data_shape)
Expand All @@ -80,25 +53,22 @@ def error(gpu_res, cpu_res):
res /= sum([np.sum(np.abs(g.asnumpy())) for g in cpu_res])
return res

def run():
args = parse_args();
def run(network, optimizer, gpus, kv_store, image_shape, disp_batches,
num_batches, test_results, **kwargs):
# create kvstore and optimizer
devs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
kv = mx.kv.create(args.kv_store)
if args.optimizer == 'None':
optimizer = None
devs = [mx.gpu(int(i)) for i in gpus.split(',')]
kv = mx.kv.create(kv_store)
if optimizer is None or optimizer == 'None':
opt = None
else:
optimizer = mx.optimizer.Optimizer.create_optimizer(args.optimizer)
updater = mx.optimizer.get_updater(mx.optimizer.Optimizer.create_optimizer(args.optimizer))
kv.set_optimizer(optimizer)
opt = mx.optimizer.Optimizer.create_optimizer(optimizer)
kv.set_optimizer(opt)
updater = mx.optimizer.get_updater(mx.optimizer.Optimizer.create_optimizer(optimizer))

# create network
if args.network == 'resnet':
symbol = get_resnet(args)
else:
import importlib
symbol = importlib.import_module('symbol_' + args.network).get_symbol(args.num_classes)
data_shape = tuple([int(s) for s in args.data_shape.split(',')])
symbol = import_module(network).get_symbol(image_shape=image_shape, **kwargs)
# a fake batch size 32, which does not affect the results
data_shape = (32,) + tuple([int(s) for s in image_shape.split(',')])
shapes = get_shapes(symbol, data_shape)

size = float(sum([reduce(lambda x,y : x*y, s, 1) for s in shapes])) * 4 / 1e6
Expand All @@ -114,7 +84,10 @@ def run():
cpu_grads = [mx.nd.array(sum([g.asnumpy() for g in gs]))*kv.num_workers for gs in grads_val]
cpu_weights = [mx.nd.zeros(s) for s in shapes]
toc = 0
for b in range(0, args.num_batches+1):

Results = namedtuple('Results', ['iter', 'time', 'bandwidth', 'error'])
res = []
for b in range(0, num_batches+1):
tic = time.time()
for i,g in enumerate(grads):
kv.push(i, g, i)
Expand All @@ -125,8 +98,8 @@ def run():
for w in ws:
w.wait_to_read()
toc += time.time() - tic
if args.test_results:
if optimizer == None:
if test_results:
if opt == None:
err = error(weights, cpu_grads)
else:
for i, wg in enumerate(zip(cpu_weights, cpu_grads)):
Expand All @@ -135,13 +108,18 @@ def run():
else:
err = -1

if b % args.disp_batches == 0:
toc /= args.disp_batches
if b % disp_batches == 0:
toc /= disp_batches
if b != 0:
# 0 is used for warmup, ignored
r = Results(iter=b, time=toc, error=err,
bandwidth=size*2*(len(devs)-1)/len(devs)/toc/1e3)
logging.info('iter %d, %f sec, %f GB/sec per gpu, error %f' % (
b, toc, size*2*(len(devs)-1)/len(devs)/toc/1e3, err))
r.iter, r.time, r.bandwidth, r.error))
res.append(r)
toc = 0
return res

if __name__ == "__main__":
run()
args = parse_args();
run(**vars(args))
29 changes: 29 additions & 0 deletions tools/bandwidth/test_measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
test measure.py
"""
from measure import run
import subprocess
import logging
def get_gpus():
    """Return the ids of the detected NVIDIA GPUs as a comma-separated string.

    Runs ``nvidia-smi -L`` and counts the output lines that contain ``GPU``;
    the returned ids are simply ``0 .. count-1``.

    Returns
    -------
    str
        For example ``'0,1'`` when two GPU lines are listed, or ``''`` when
        ``nvidia-smi`` cannot be executed or exits with a non-zero status.
    """
    try:
        # named `listing` so the standard `re` module name is not shadowed
        listing = subprocess.check_output(["nvidia-smi", "-L"],
                                          universal_newlines=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError: the binary could not be started;
        # CalledProcessError: it ran but reported failure.
        return ''
    gpu_lines = [line for line in listing.splitlines() if 'GPU' in line]
    return ','.join(str(i) for i in range(len(gpu_lines)))

def test_measure(**kwargs):
    """Run a short bandwidth measurement and check its reported error.

    ``kwargs`` is logged and then forwarded unchanged to ``run`` together
    with the fixed settings below.
    """
    logging.info(kwargs)
    results = run(
        image_shape='3,224,224',
        num_classes=1000,
        num_layers=50,
        disp_batches=2,
        num_batches=2,
        test_results=1,
        **kwargs)
    # a single averaged result is expected for this configuration
    assert len(results) == 1
    first = results[0]
    assert first.error < 1e-4

if __name__ == '__main__':
    gpus = get_gpus()
    # Compare by value: `is not ''` tests object identity, which is not a
    # reliable way to detect an empty string.
    assert gpus != '', 'get_gpus() found no GPUs'
    # (network, optimizer, kv_store) combinations to measure, in order
    configs = [
        ('alexnet', None, 'device'),
        ('resnet', 'sgd', 'device'),
        ('inception-bn', None, 'local'),
        ('resnet', None, 'local'),
        ('resnet', 'sgd', 'local'),
    ]
    for network, optimizer, kv_store in configs:
        test_measure(gpus=gpus, network=network, optimizer=optimizer,
                     kv_store=kv_store)

0 comments on commit 54b3dc4

Please sign in to comment.