Skip to content

Commit

Permalink
Update Graph Support for Batching, Fix Swapping (apache#37)
Browse files Browse the repository at this point in the history
* fix graph transform for batch dimension

* fix

* fix
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent a96a4a9 commit 8c9758b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 39 deletions.
25 changes: 19 additions & 6 deletions vta/examples/resnet18/pynq/imagenet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tvm
from nnvm.compiler import graph_attr
import vta
import vta.testing
import os
import numpy as np
from PIL import Image
Expand All @@ -12,7 +13,8 @@
import wget
from tvm.contrib import graph_runtime, rpc, util

factor = 16
bfactor = 1
cfactor = 16
host = "pynq"
port = 9091
verbose = False
Expand All @@ -38,6 +40,10 @@
target = tvm.target.create("llvm -device=vta")
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"

if vta.get_env().TARGET == "sim":
target_host = "llvm"


synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))

Expand Down Expand Up @@ -105,7 +111,7 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if target.device_name == "vta":
sym = vta.graph.pack(sym, shape_dict, factor)
sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)

graph_attr.set_shape_inputs(sym, shape_dict)
sym = sym.apply("InferShape")
Expand All @@ -127,7 +133,13 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
assert tvm.module.enabled("rpc")
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote = rpc.connect(host, port)

if vta.get_env().TARGET == "sim":
remote = rpc.LocalSession()
print("local session")
else:
remote = rpc.connect(host, port)

remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
Expand All @@ -154,16 +166,17 @@ def run_e2e(graph):
print("t-cost=%g" % tcost.mean)


def run_layer(old_graph):
def run_layer(old_graph, layer_begin, layer_end):
"""Run a certain layer."""
for layer_id in range(1, 2):
for layer_id in range(layer_begin, layer_end):
print("run resnet[%d]..."% (layer_id))
graph = mark_nop(old_graph, layer_id)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(x.astype("float32")))
m.set_input(**params)
# execute
timer = m.module.time_evaluator("run", ctx, number=10)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()
print("resnet[%d]: %g\n"% (layer_id, tcost.mean))

Expand Down
76 changes: 44 additions & 32 deletions vta/python/vta/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,58 @@
from nnvm.compiler import graph_attr, graph_util


def _pack_channel(data, dshape, factor):
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension.
"""
assert dshape[1] % factor == 0
assert dshape[0] % bfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0], dshape[1] // factor,
factor, dshape[2], dshape[3]))
shape=(dshape[0] // bfactor, bfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 1, 3, 4, 2))
data, axes=(0, 2, 4, 5, 1, 3))
return data


def _unpack_channel(data, old_shape):
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3))
data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = nnvm.sym.reshape(data, shape=old_shape)
return data


def _pack_weight(data, dshape, factor):
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % factor == 0
assert dshape[1] % factor == 0
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // factor, factor,
dshape[1] // factor, factor,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data


def _pack_bias(data, dshape, factor):
def _pack_bias(data, dshape, bfactor, cfactor):
"""Pack the bias parameter.
"""
assert len(dshape) == 3
assert dshape[0] % factor == 0
assert dshape[0] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // factor,
factor, dshape[1], dshape[2]))
shape=(dshape[0] // cfactor,
cfactor, dshape[1],
dshape[2], 1))
data = nnvm.sym.transpose(
data, axes=(0, 2, 3, 1))
data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = nnvm.sym.broadcast_to(
data,
shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
return data


Expand Down Expand Up @@ -245,8 +252,8 @@ def _clean_cast(node, target_type):
return ret


def pack(graph, shape_dict, factor, start_name=None):
"""Pack the graph into channel packed format.
def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
"""Pack the graph into batch&channel packed format.
Parameters
----------
Expand All @@ -256,8 +263,11 @@ def pack(graph, shape_dict, factor, start_name=None):
shape_dict : dict of str to shapex
The input shape.
factor : int
The packing factor
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
start_name: str, optional
Start name start packing from certain known node.
Expand Down Expand Up @@ -290,42 +300,44 @@ def pack(graph, shape_dict, factor, start_name=None):
new_node = nnvm.symbol.Variable(node_name)
if start_name and node_name == start_name:
start_pack = True
new_node = _pack_channel(new_node, oshape, factor)
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif op_name == "max_pool2d":
assert not start_pack
start_pack = True
new_node = get_clone(children, op_name, node_name, attrs)
new_node = _pack_channel(new_node, oshape, factor)
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif op_name == "global_avg_pool2d":
if start_pack:
start_pack = False
children[0] = _unpack_channel(children[0], ishape[0])
children[0] = _unpack_batch_channel(children[0], ishape[0])
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "quantized_conv2d":
if start_pack:
attrs["pack_channel"] = str(factor)
attrs["pack_batch"] = str(bfactor)
attrs["pack_channel"] = str(cfactor)
data, weight = children
weight = _pack_weight(weight, ishape[1], factor)
weight = _pack_weight(weight, ishape[1], cfactor)
new_node = nnvm.sym.quantized_conv2d(
data, weight, name=node_name, **attrs)
elif counter == 1:
attrs["pack_channel"] = str(factor)
attrs["pack_batch"] = str(bfactor)
attrs["pack_channel"] = str(cfactor)
data, weight = children
data = _pack_channel(data, ishape[0], factor)
weight = _pack_weight(weight, ishape[1], factor)
data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
weight = _pack_weight(weight, ishape[1], cfactor)
new_node = nnvm.sym.quantized_conv2d(
data, weight, name=node_name, **attrs)
new_node = _unpack_channel(new_node, oshape)
new_node = _unpack_batch_channel(new_node, oshape)
counter = counter + 1
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast"):
if start_pack:
assert len(ishape[1]) == 3
children[1] = _pack_bias(children[1], ishape[1], factor)
children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
Expand All @@ -341,7 +353,7 @@ def pack(graph, shape_dict, factor, start_name=None):
ret = node_map[graph.index.output_entries[0][0]]
if start_pack:
oshape = shape[graph.index.output_entries[0][0]]
ret = _unpack_channel(ret, oshape)
ret = _unpack_batch_channel(ret, oshape)
graph = nnvm.graph.create(ret)
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
Expand Down
4 changes: 3 additions & 1 deletion vta/src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ class UopQueue : public BaseQueue {
}
assert(num_op <= kMaxNumUop);
uint32_t uop_begin = 0;
if (sram_end_ + num_op > kMaxElems) {
if (sram_end_ + num_op > kMaxNumUop) {
// Need to evict
cache_ptr_ = 0;
sram_begin_ = 0;
sram_end_ = num_op;
} else {
uop_begin = sram_end_;
Expand All @@ -388,6 +389,7 @@ class UopQueue : public BaseQueue {
dram_end_ += num_op;
kernel->sram_begin_ = uop_begin;
kernel->sram_end_ = sram_end_;
CHECK(kernel->cached());
assert(uop_begin != sram_end_);
cache_.insert(cache_.begin() + cache_ptr_, kernel);
cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_ptr_);
Expand Down
1 change: 1 addition & 0 deletions vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class DRAM {
*/
void Free(void* data) {
std::lock_guard<std::mutex> lock(mutex_);
if (pmap_.size() == 0) return;
auto it = pmap_.find(data);
CHECK(it != pmap_.end());
Page* p = it->second.get();
Expand Down

0 comments on commit 8c9758b

Please sign in to comment.