Skip to content

Commit a81a167

Browse files
tmoreau89tqchen
authored andcommitted
[UTILS, DOC] Use TVM file downloading utility, conv2d tutorial (apache#48)
1 parent 5c8177b commit a81a167

File tree

9 files changed

+156
-108
lines changed

9 files changed

+156
-108
lines changed

vta/examples/resnet18/pynq/README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,20 @@
22

33
Follow the first two parts of the [Installation Guide](../../../docs/how_to/install.md) to make sure that the VTA python libraries are installed, and that the RPC server is running on the Pynq FPGA dev board.
44

5-
Simply run the following python script:
5+
We recommend leaving the `config.json` to its default parameterization (of course you can change the target between "sim" and "pynq").
6+
7+
Simply run the example program. We rely on pickle to store parameters which now only works with python2.
68
```bash
7-
python imagenet_predict.py
9+
python2 imagenet_predict.py
810
```
911

10-
This will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`.
12+
The script will first download the following files into `_data/` directory:
13+
* `cat.jpg` which provides a test sample for the ImageNet classifier
14+
* `quantize_graph.json` which describes the NNVM graph of the 8-bit ResNet-18
15+
* `quantize_params.plk` which contains the network parameters
16+
* `synset.txt` which contains the ImageNet categories
17+
18+
Next, it will run imagenet classification using the ResNet18 architecture on a VTA design that performs 8-bit integer inference, to perform classification on a cat image `cat.jpg`.
1119

1220
The script reports runtime measured on the Pynq board (in seconds), and the top-1 result category:
1321
```

vta/examples/resnet18/pynq/imagenet_predict.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# some standard imports
22
import nnvm
33
import tvm
4-
from nnvm.compiler import graph_attr
54
import vta
65
import vta.testing
76
import os
87
import numpy as np
9-
from PIL import Image
108
import pickle
119
import json
1210
import logging
13-
import wget
11+
12+
from PIL import Image
13+
from nnvm.compiler import graph_attr
1414
from tvm.contrib import graph_runtime, rpc, util
15+
from tvm.contrib.download import download
1516

1617
bfactor = 1
1718
cfactor = 16
@@ -20,15 +21,20 @@
2021
debug_fpga_only = False
2122

2223
# Obtain model and hardware files (they're too large to check-in)
24+
# Download them into _data dir
25+
data_dir = "_data/"
2326
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
2427
TEST_FILE = 'cat.jpg'
2528
CATEG_FILE = 'synset.txt'
2629
RESNET_GRAPH_FILE = 'quantize_graph.json'
2730
RESNET_PARAMS_FILE = 'quantize_params.pkl'
31+
# Create data dir
32+
if not os.path.exists(data_dir):
33+
os.makedirs(data_dir)
34+
# Download files
2835
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE]:
2936
if not os.path.isfile(file):
30-
print ("Downloading {}".format(file))
31-
wget.download(url+file)
37+
download(os.path.join(url, file), os.path.join(data_dir, file))
3238

3339
if verbose:
3440
logging.basicConfig(level=logging.DEBUG)
@@ -40,8 +46,8 @@
4046
if vta.get_env().TARGET == "sim":
4147
target_host = "llvm"
4248

43-
synset = eval(open(os.path.join(CATEG_FILE)).read())
44-
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
49+
synset = eval(open(os.path.join(data_dir, CATEG_FILE)).read())
50+
image = Image.open(os.path.join(data_dir, TEST_FILE)).resize((224, 224))
4551

4652
def transform_image(image):
4753
image = np.array(image) - np.array([123., 117., 104.])
@@ -88,9 +94,9 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
8894
import nnvm.compiler
8995
np.random.seed(0)
9096
sym = nnvm.graph.load_json(
91-
open(os.path.join(RESNET_GRAPH_FILE)).read())
97+
open(os.path.join(data_dir, RESNET_GRAPH_FILE)).read())
9298
params = pickle.load(
93-
open(os.path.join(RESNET_PARAMS_FILE)))
99+
open(os.path.join(data_dir, RESNET_PARAMS_FILE), 'rb'))
94100

95101
shape_dict = {"data": x.shape}
96102
dtype_dict = {"data": 'float32'}

vta/python/vta/bitstream.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@
22
from __future__ import absolute_import as _abs
33

44
import os
5-
import urllib
5+
import sys
6+
7+
from tvm.contrib.download import download
68
from .environment import get_env
79

10+
if sys.version_info >= (3,):
11+
import urllib.error as urllib2
12+
else:
13+
import urllib2
14+
815
# bitstream repo
916
BITSTREAM_URL = "https://github.com/uwsaml/vta-distro/raw/master/bitstreams/"
1017

@@ -41,15 +48,25 @@ def download_bitstream():
4148
url = os.path.join(BITSTREAM_URL, env.TARGET)
4249
url = os.path.join(url, env.HW_VER)
4350
url = os.path.join(url, env.BITSTREAM)
44-
# Check that the bitstream is accessible from the server
45-
if urllib.urlopen(url).getcode() == 404:
46-
# Raise error - the solution when this happens it to build your own bitstream and add it
47-
# to your VTA_CACHE_PATH
48-
raise RuntimeError(
49-
"Error: {} is not available. It appears that this configuration has not been built."
50-
.format(url))
51-
else:
52-
urllib.urlretrieve(url, bit)
53-
success = True
51+
52+
try:
53+
download(url, bit)
54+
except urllib2.HTTPError as err:
55+
if err.code == 404:
56+
raise RuntimeError(
57+
# Raise error - the solution when this happens it to build your
58+
# own bitstream and add it to your $VTA_CACHE_PATH
59+
"{} is not available. It appears that this configuration \
60+
bistream has not been cached. Please compile your own bitstream (see hardware \
61+
compilation guide to get Xilinx toolchains setup) and add it to your \
62+
$VTA_CACHE_PATH. Alternatively edit your config.json back to its default \
63+
settings. You can see the list of available bitstreams under {}"
64+
.format(url, BITSTREAM_URL))
65+
else:
66+
raise RuntimeError(
67+
# This could happen when trying to access the URL behind a proxy
68+
"Something went wrong when trying to access {}. Check your \
69+
internet connection or proxy settings."
70+
.format(url))
5471

5572
return success

vta/python/vta/testing/util.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,34 @@ def run(run_func):
1515
"""
1616
env = get_env()
1717

18-
# Run on local sim rpc if necessary
19-
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
20-
if local_rpc:
21-
env.TARGET = "sim"
22-
remote = rpc.connect("localhost", local_rpc)
23-
run_func(env, remote)
24-
else:
25-
# run on simulator
26-
if simulator.enabled():
27-
env.TARGET = "sim"
18+
if env.TARGET == "sim":
19+
20+
# Talk to local RPC if necessary to debug RPC server.
21+
# Compile vta on your host with make at the root.
22+
# Make sure TARGET is set to "sim" in the config.json file.
23+
# Then launch the RPC server on the host machine
24+
# with ./apps/pynq_rpc/start_rpc_server.sh
25+
# Set your VTA_LOCAL_SIM_RPC environment variable to
26+
# the port it's listening to, e.g. 9090
27+
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
28+
if local_rpc:
29+
remote = rpc.connect("localhost", local_rpc)
30+
run_func(env, remote)
31+
else:
32+
# Make sure simulation library exists
33+
# If this fails, build vta on host (make)
34+
# with TARGET="sim" in the json.config file.
35+
assert simulator.enabled()
2836
run_func(env, rpc.LocalSession())
2937

30-
# Run on PYNQ if env variable exists
31-
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
32-
if host:
33-
env.TARGET = "pynq"
34-
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
35-
port = int(port)
36-
remote = rpc.connect(host, port)
37-
run_func(env, remote)
38+
elif env.TARGET == "pynq":
39+
40+
# Run on PYNQ if env variable exists
41+
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
42+
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
43+
if host and port:
44+
remote = rpc.connect(host, port)
45+
run_func(env, remote)
46+
else:
47+
raise RuntimeError(
48+
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")

vta/tests/python/integration/test_benchmark_gemm.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def run_gemm_packed(env, remote, batch_size, channel, block):
1818
channel // env.BLOCK_OUT,
1919
env.BATCH,
2020
env.BLOCK_OUT)
21-
num_ops = channel * channel * batch_size
21+
# To compute number of ops, use a x2 factor for FMA
22+
num_ops = 2 * channel * channel * batch_size
2223

2324
ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
2425
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
@@ -157,14 +158,14 @@ def run_schedule(load_inp,
157158

158159
def gemm_normal(print_ir):
159160
mock = env.mock
160-
print("----- GEMM GFLOPS End-to-End Test-------")
161+
print("----- GEMM GOPS End-to-End Test-------")
161162
def run_test(header, print_ir, check_correctness):
162163
cost = run_schedule(
163164
env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
164165
print_ir, check_correctness)
165166
gops = (num_ops / cost.mean) / float(10 ** 9)
166167
print(header)
167-
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
168+
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
168169
with vta.build_config():
169170
run_test("NORMAL", print_ir, True)
170171

@@ -177,7 +178,7 @@ def run_test(header, print_ir):
177178
print_ir, False)
178179
gops = (num_ops / cost.mean) / float(10 ** 9)
179180
print(header)
180-
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
181+
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
181182
with vta.build_config():
182183
run_test("NORMAL", print_ir)
183184

@@ -190,7 +191,7 @@ def run_test(header, print_ir):
190191
print_ir, False)
191192
gops = (num_ops / cost.mean) / float(10 ** 9)
192193
print(header)
193-
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
194+
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
194195
with vta.build_config():
195196
run_test("NORMAL", print_ir)
196197
print("")
@@ -204,7 +205,7 @@ def run_test(header, print_ir):
204205
gops = (num_ops / cost.mean) / float(10 ** 9)
205206
bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
206207
print(header)
207-
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
208+
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
208209
cost.mean, gops, bandwith))
209210
with vta.build_config():
210211
run_test("NORMAL", print_ir)
@@ -219,7 +220,7 @@ def run_test(header, print_ir):
219220
gops = (num_ops / cost.mean) / float(10 ** 9)
220221
bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
221222
print(header)
222-
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
223+
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
223224
cost.mean, gops, bandwith))
224225
with vta.build_config():
225226
run_test("NORMAL", print_ir)
@@ -235,7 +236,7 @@ def run_test(header, print_ir):
235236
gops = (num_ops / cost.mean) / float(10 ** 9)
236237
bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
237238
print(header)
238-
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
239+
print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
239240
cost.mean, gops, bandwith))
240241
with vta.build_config():
241242
run_test("NORMAL", print_ir)

vta/tests/python/integration/test_benchmark_topi_conv2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
4242
res = my_clip(res, 0, 127)
4343
res = topi.cast(res, "int8")
4444

45+
# To compute number of ops, use a x2 factor for FMA
4546
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
4647

4748
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
@@ -118,7 +119,7 @@ def conv_normal(print_ir):
118119
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
119120
cost = verify(s, True)
120121
gops = (num_ops / cost.mean) / float(10 ** 9)
121-
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
122+
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
122123

123124
conv_normal(False)
124125

0 commit comments

Comments
 (0)