[NNVM] Add missed part of annotation (#10)
* add missed part of annotation

* fix check_computation and slice_like

* keep _build as before

* fix vta failure
zhiics authored and wweic committed Mar 11, 2019
1 parent f37881d commit dc06e32
Showing 17 changed files with 513 additions and 435 deletions.
109 changes: 57 additions & 52 deletions nnvm/python/nnvm/compiler/build_module.py
@@ -25,7 +25,7 @@

class AnnotationType(IntEnum):
"""The purpose of annotation."""
TARGET = 1 # Only set target to the node attribute.
HOMO_TARGET = 1 # Only set the same target on every node attribute.
DEVICE_TARGET = 2 # Annotate both device type and target info to a node.
COPY_INSERTION = 3 # Annotate device type and target. Insert copy node.

@@ -44,6 +44,8 @@ class BuildConfig(object):
"opt_level": 2,
"add_pass": None,
"ext_accel": None,
"fallback_device": None,
"op_name_device": None,
}
def __init__(self, **kwargs):
self._old_scope = None
@@ -105,6 +107,13 @@ def build_config(**kwargs):
ext_accel: str
External accelerator for optimizing the operators it supports in the whole graph.
fallback_device : str or tvm.TVMContext
The fallback device. It is also used as the default device for
operators without a specified device during heterogeneous execution.
op_name_device : dict of str to str or tvm.TVMContext
A dictionary mapping operator names to device contexts.
Returns
-------
config: BuildConfig
@@ -131,11 +140,11 @@ def _lower(sch, inputs, func_name, graph):
f, (tvm.container.Array, tuple, list)) else [f]


@tvm.register_func("nnvm.compiler.build_module")
def _build(funcs, target_host):
@tvm.register_func("nnvm.compiler.build_target")
def _build(funcs, target, target_host):
if target_host == "":
target_host = None
return tvm.build(funcs, target_host=target_host)
return tvm.build(funcs, target=target, target_host=target_host)


def _update_shape_dtype(shape, dtype, params):
Expand Down Expand Up @@ -203,8 +212,7 @@ def optimize(graph, shape, dtype="float32", layout=None, target=None):


def build(graph, target=None, shape=None, dtype="float32",
params=None, target_host=None, layout=None, op_name_device=None,
fallback_device=None):
params=None, target_host=None, layout=None):
"""Build graph into runtime library.
The build function will optimize the graph and do the compilation.
@@ -218,8 +226,9 @@ def build(graph, target=None, shape=None, dtype="float32",
graph : Graph
The graph to be used in lowering
target : str or :any:`tvm.target.Target`, optional
The build target
target : str, :any:`tvm.target.Target`, or dict of str to str, optional
The build target, or a dictionary mapping device names to compilation
targets.
shape : dict of str to tuple, optional
The input shape to the graph
@@ -244,12 +253,6 @@ def build(graph, target=None, shape=None, dtype="float32",
layout : dict of str to str or str, optional
The input layout
op_name_device : dict of str to int.
A dictionary contains operator name to device mapping.
fallback_device : TVMContext.
The fallback device.
Returns
-------
graph : Graph
@@ -313,20 +316,23 @@ def build(graph, target=None, shape=None, dtype="float32",
if _all_var_init:
init_var = initialize_variables(shape, dtype)

_annotate_graph(graph, device_target, op_name_device, fallback_device)
graph = _annotate_graph(graph, device_target,
AnnotationType.DEVICE_TARGET)
# Apply optimization
graph = optimize(graph, shape, dtype, layout, target)
graph = optimize(graph, shape, dtype, layout)

# Clear extra params without nodes.
_remove_noref_params(params, graph)

_annotate_graph(graph, device_target)
graph = _annotate_graph(graph, device_target,
AnnotationType.HOMO_TARGET)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
_annotate_graph(graph, device_target, op_name_device, fallback_device,
insert_copy_node=True)
graph = _annotate_graph(graph, device_target,
AnnotationType.COPY_INSERTION)

# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
@@ -352,14 +358,8 @@ def build(graph, target=None, shape=None, dtype="float32",

def _annotate_graph(graph,
device_target,
op_name_device=None,
fallback_device=None,
insert_copy_node=False):
"""Helper function to anntoate the graph. Both the target and the device
info of a graph node will be annotated if `op_name_device` is set.
Otherwise, only the target info will be attached. `insert_copy_node`
indicates if we need to insert cross device data copy node. It is only for
heterogeneous execution purpose.
annotation_type):
"""Helper function to anntoate the graph according to the annotation type.
Parameters
----------
@@ -370,37 +370,42 @@ def _annotate_graph(graph,
A dictionary containing device type to compilation target pairs that will
be used to build the graph.
op_name_device : dict of str to int.
A dictionary contains operator name to device mapping.
fallback_device : TVMContext.
The fallback device.
insert_copy_node : bool.
A bool value indicates whether or not a cross device data copy node is
required.
annotation_type : AnnotationType.
The annotation type. This is used to indicate whether we annotate all
nodes with the same target (AnnotationType.HOMO_TARGET), attach different
targets to different nodes (AnnotationType.DEVICE_TARGET), or attach
targets and insert cross-device copy nodes (AnnotationType.COPY_INSERTION).
Returns
-------
graph : Graph.
The updated graph.
"""
annotation_type = AnnotationType.TARGET
if op_name_device:
annotation_type = AnnotationType.COPY_INSERTION if insert_copy_node \
else AnnotationType.DEVICE_TARGET
if not isinstance(op_name_device, dict):
raise ValueError("op_name_device must be a dictionary.")
fallback_device = fallback_device if fallback_device else tvm.cpu(0)
if not isinstance(fallback_device, TVMContext):
raise ValueError("fallback_device must be the type of TVMContext.")
op_name_device.update((name, tvm.context(dev).device_type)
for name, dev in op_name_device.items())
graph._set_json_attr("fallback", fallback_device.device_type, "int")
graph._set_json_attr("op_name", list(op_name_device.keys()),
"list_str")
graph._set_json_attr("op_device", list(op_name_device.values()),
"list_int")
if not isinstance(annotation_type, AnnotationType):
raise ValueError("annotation_type must be the type of AnnotationType")

if annotation_type != AnnotationType.HOMO_TARGET:
# Heterogeneous execution.
if len(device_target) > 1 or 0 not in device_target:
op_name_device = BuildConfig.current.op_name_device
op_name_device = op_name_device if op_name_device else {}
if not isinstance(op_name_device, dict):
raise ValueError("op_name_device must be a dictionary of operator "
"name to device context.")
fallback_device = BuildConfig.current.fallback_device
fallback_device = fallback_device if fallback_device else tvm.cpu(0)
if not isinstance(fallback_device, TVMContext):
raise ValueError("fallback_device must be the type of TVMContext.")
op_name_device.update((name, tvm.context(dev).device_type)
for name, dev in op_name_device.items())
graph._set_json_attr("fallback", fallback_device.device_type, "int")
graph._set_json_attr("op_name", list(op_name_device.keys()),
"list_str")
graph._set_json_attr("op_device", list(op_name_device.values()),
"list_int")
else:
# Homogeneous execution.
annotation_type = AnnotationType.HOMO_TARGET

graph._set_json_attr("annotation_type", int(annotation_type), "int")
graph._set_json_attr("device_type", list(device_target.keys()), "list_int")
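
To make the intent of the build_module.py changes concrete, here is a minimal sketch of how the new heterogeneous options might be driven from user code. The variables graph, shape, and params are placeholders, and the device-name keys of the target dictionary ("cpu", "gpu") and the "cuda" context string are assumptions inferred from the docstrings above, not a tested invocation:

import tvm
import nnvm.compiler

# `graph`, `shape`, and `params` are assumed to come from an imported model.
# Route conv2d to the GPU; every other operator falls back to the CPU.
with nnvm.compiler.build_config(opt_level=2,
                                fallback_device=tvm.cpu(0),
                                op_name_device={"conv2d": "cuda"}):
    deploy_graph, lib, params = nnvm.compiler.build(
        graph,
        target={"cpu": "llvm", "gpu": "cuda"},  # device name -> compilation target
        shape=shape,
        dtype="float32",
        params=params)
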
135 changes: 73 additions & 62 deletions nnvm/python/nnvm/top/nn.py
@@ -61,9 +61,10 @@ def schedule_log_softmax(_, outs, target):
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])
with tvm.target.create(attrs.get_str("target")):
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])

@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
@@ -95,37 +96,42 @@ def compute_conv2d(attrs, inputs, _):
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
with tvm.target.create(attrs.get_str("target")):
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout,
out_dtype=out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation,
out_dtype=out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides,
padding, dilation, groups,
out_dtype=out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation,
out_dtype=out_dtype)
else:
raise ValueError("not support arbitrary group number for now")

if attrs.get_bool("use_bias"):
bias = inputs[2]
expand_axis = 1 if layout == "NCHW" else 0
bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
out = topi.add(out, bias)
return out
if attrs.get_bool("use_bias"):
bias = inputs[2]
expand_axis = 1 if layout == "NCHW" else 0
bias = topi.expand_dims(bias, axis=expand_axis, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("conv2d")
def schedule_conv2d(attrs, outs, target):
@@ -171,7 +177,8 @@ def _reshape(*args, **kwargs):
return raw_reshape(*args, **kwargs)
sym.reshape = _reshape

return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, sym)
with tvm.target.create(attrs.get_str("target")):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, sym)

reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -194,21 +201,24 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
_, in_channel_chunk, _, _, in_channel_block = get_const_tuple(inputs[0].shape)
in_channel = in_channel_chunk * in_channel_block
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
layout, out_layout, out_dtype)
elif groups == in_channel and groups == out_channel:
out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
dilation, layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

with tvm.target.create(attrs.get_str("target")):
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
dilation, layout, out_layout, out_dtype)
elif groups == in_channel and groups == out_channel:
out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides,
padding, dilation, layout,
out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
@@ -228,7 +238,7 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

@reg.register_compute("_contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _):
with tvm.target.create(attrs.get_string("target")):
with tvm.target.create(attrs.get_str("target")):
return topi.nn.conv2d_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'))

@@ -254,16 +264,17 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _):
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not support arbitrary group number"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype,
tile_size)
with tvm.target.create(attrs.get_str("target")):
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, layout,
out_dtype, tile_size)

if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
@@ -290,7 +301,7 @@ def compute_conv2d_transpose(attrs, inputs, _):
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"

with tvm.target.create(attrs.get_string("target")):
with tvm.target.create(attrs.get_str("target")):
out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides,
padding, out_dtype)
if attrs.get_bool("use_bias"):
@@ -388,7 +399,7 @@ def compute_lrn(attrs, inputs, _):
alpha = attrs.get_float("alpha")
beta = attrs.get_float("beta")
bias = attrs.get_float("bias")
with tvm.target.create(attrs.get_string("target")):
with tvm.target.create(attrs.get_str("target")):
return topi.nn.lrn(inputs[0], size, axis, alpha, beta, bias)

@reg.register_schedule("lrn")
@@ -404,7 +415,7 @@ def compute_l2_normalize(attrs, inputs, _):
"""Compute definition of l2 normalize"""
eps = attrs.get_float("eps")
axis = attrs.get_int_tuple("axis")
with tvm.target.create(attrs.get_string("target")):
with tvm.target.create(attrs.get_str("target")):
return topi.nn.l2_normalize(inputs[0], eps, axis)

@reg.register_schedule("l2_normalize")
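
The recurring change in nn.py is that each compute function now enters the target scope recorded on the node attributes (via attrs.get_str("target")) before calling into TOPI, so the per-node target chosen by the annotation pass decides which backend implementation is dispatched. Below is a minimal sketch of that pattern for a hypothetical element-wise operator; the function name and the use of topi.nn.relu are illustrative assumptions, not part of this commit, and the function would still need to be registered with reg.register_compute like the ones above:

import tvm
import topi

def compute_elemwise_with_target(attrs, inputs, _):
    """Sketch of the compute pattern used throughout this file."""
    # The annotation pass stores the compilation target in the node attributes;
    # entering that target scope lets TOPI dispatch to the matching backend.
    with tvm.target.create(attrs.get_str("target")):
        return topi.nn.relu(inputs[0])
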