Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] GPU support for tensorflow models. (apache#1718)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Sep 20, 2018
1 parent ae5a28d commit fdf795a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 39 deletions.
90 changes: 64 additions & 26 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
Expand Down Expand Up @@ -121,6 +122,9 @@ def _pooling(name):
def _impl(inputs, attr, params):

attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

input_shape = attr['_input_shapes'][inputs[0]][0]

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
Expand All @@ -129,11 +133,17 @@ def _impl(inputs, attr, params):
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]][0]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
flip_layout = True

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")

if attr['padding'] == 'VALID':
Expand All @@ -142,11 +152,11 @@ def _impl(inputs, attr, params):
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
in_h = input_shape[2]
in_w = input_shape[3]

pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
Expand All @@ -158,41 +168,61 @@ def _impl(inputs, attr, params):
if name == "avg_pool":
attr['count_include_pad'] = False

return AttrCvt(
out = AttrCvt(
op_name=_dimension_picker(name),
transforms={
'kernel_shape':'pool_size',
'data_format':'layout'},
ignores=['ksize'],
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr)

if flip_layout:
out = _sym.transpose(out, axes=(0, 2, 3, 1))

return out
return _impl

def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
input_shapes = attr['_input_shapes'][inputs[0]]
flip_layout = False

# Extract kernel shape from params
conv_param_weights = params[inputs[1].list_output_names()[0]]
input_shape = attr['_input_shapes'][inputs[0]][0]
weights_shape = params[inputs[1].list_output_names()[0]].shape

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
if opname == 'conv':
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))

attr['data_format'] = "NCHW"
flip_layout = True

if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[3]
attr['channels'] = weights_shape[3]
else:
attr['channels'] = input_shapes[0][3] * depth_mult
attr['channels'] = input_shape[3] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
depth_mult, _, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[1]
attr['channels'] = weights_shape[0]
else:
attr['channels'] = input_shapes[0][1] * depth_mult
attr['channels'] = input_shape[0] * depth_mult
if attr['channels'] < 0:
attr['channels'] *= -1

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
Expand All @@ -215,11 +245,11 @@ def _impl(inputs, attr, params):
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]
in_h = input_shape[2]
in_w = input_shape[3]

pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
Expand Down Expand Up @@ -248,7 +278,7 @@ def _impl(inputs, attr, params):
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

return AttrCvt(
out = AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
Expand All @@ -257,6 +287,11 @@ def _impl(inputs, attr, params):
'group': ('groups', 1)},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)

if flip_layout:
out = _sym.transpose(out, axes=(0, 2, 3, 1))

return out
return _impl

def _decode_image():
Expand Down Expand Up @@ -305,7 +340,7 @@ def _matmul():
def _impl(inputs, attr, params):
channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
if attr['transpose_a']:
inputs[0] = _sym.transpose(inputs[0], axis(1, 0))
inputs[0] = _sym.transpose(inputs[0], axes(1, 0))
if not attr['transpose_b']:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
return AttrCvt(op_name="dense",
Expand Down Expand Up @@ -948,7 +983,7 @@ def __init__(self):
self._num_param = 0
self._num_rnn_layer = False

def from_tensorflow(self, graph):
def from_tensorflow(self, graph, layout="NHWC"):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
Expand Down Expand Up @@ -1036,6 +1071,9 @@ def from_tensorflow(self, graph):
# Pass the node name too in attr
attr["_node_name"] = node.name

# Pass the target layout
attr["_target_layout"] = layout

#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
Expand Down Expand Up @@ -1265,7 +1303,7 @@ def _fix_extranodes(self, op_name, attr, inputs):

return inputs

def from_tensorflow(graph):
def from_tensorflow(graph, layout="NHWC"):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
Expand All @@ -1283,5 +1321,5 @@ def from_tensorflow(graph):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph)
sym, params = g.from_tensorflow(graph, layout)
return sym, params
28 changes: 20 additions & 8 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
""" Generic function to compile on nnvm and execute on tvm """

sym, params = nnvm.frontend.from_tensorflow(graph_def)
target = 'llvm'
layout = None
if target == "cuda":
layout = "NCHW"

sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)
target_host = 'llvm'
if isinstance(input_data, list):
shape_dict = {}
dtype_dict = {}
Expand All @@ -41,10 +45,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}

graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
dtype=dtype_dict, params=params)

ctx = tvm.cpu(0)
ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
Expand Down Expand Up @@ -106,9 +110,17 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
)

tf_output = run_tf_graph(sess, in_data, in_name, out_name)
tvm_output = run_tvm_graph(final_graph_def, in_data,
in_node, tf_output.shape, tf_output.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue

tvm_output = run_tvm_graph(final_graph_def, in_data,
in_node, tf_output.shape, tf_output.dtype, target=device)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

sess.close()

#######################################################################
Expand Down
18 changes: 13 additions & 5 deletions tutorials/nnvm/from_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
lable_map = 'imagenet_synset_to_human_label_map.txt'
lable_map_url = os.path.join(repo_base, lable_map)

# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)

######################################################################
# Download required files
Expand Down Expand Up @@ -99,7 +109,7 @@
# Results:
# sym: nnvm graph for given tensorflow protobuf.
# params: params converted from tensorflow params (tensor protobuf).
sym, params = nnvm.frontend.from_tensorflow(graph_def)
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)

print ("Tensorflow protobuf imported as nnvm graph")
######################################################################
Expand All @@ -113,18 +123,16 @@
# lib: target library which can be deployed on target with tvm runtime.

import nnvm.compiler
target = 'llvm'
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype=dtype_dict, params=params)
graph, lib, params = nnvm.compiler.build(sym, shape=shape_dict, target=target, target_host=target_host, dtype=dtype_dict, params=params)

######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the NNVM compiled model on cpu target.
# Now we can try deploying the NNVM compiled model on target.

from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
Expand Down

0 comments on commit fdf795a

Please sign in to comment.