54 changes: 53 additions & 1 deletion python/tvm/relay/op/contrib/mrvl.py
@@ -535,7 +535,6 @@ def avgpool2d_base_pattern(pattern):

def globalavgpool2d_pattern():
"""Create a globalavgpool2d pattern.
review tvm/tests/python/relay/test_dataflow_pattern.py for examples
Returns
-------
pattern : dataflow_pattern.AltPattern
@@ -544,6 +543,17 @@ def globalavgpool2d_pattern():
pattern = is_op("nn.global_avg_pool2d")(wildcard())
return pattern

def globalmaxpool2d_pattern():
"""Create a globalmaxpool2d pattern.
See tvm/tests/python/relay/test_dataflow_pattern.py for examples.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the globalmaxpool2d pattern.
"""
pattern = is_op("nn.global_max_pool2d")(wildcard())
return pattern
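
As a hedged aside (not part of the patch), the docstring's pointer can be made concrete: the same `is_op`/`wildcard` helpers match a bare Relay call directly.

```python
# Minimal sketch: the pattern built by globalmaxpool2d_pattern() matches a
# bare nn.global_max_pool2d call.
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

x = relay.var("x", shape=(1, 7, 7, 32), dtype="float32")
y = relay.nn.global_max_pool2d(x, layout="NHWC")

pattern = is_op("nn.global_max_pool2d")(wildcard())
assert pattern.match(y)  # the pattern matches the call node
```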

def reshape_pattern():
pattern = is_op("reshape")(wildcard())
return pattern
@@ -552,6 +562,10 @@ def batch_flatten_pattern():
pattern = is_op("nn.batch_flatten")(wildcard())
return pattern

def squeeze_pattern():
pattern = is_op("squeeze")(wildcard())
return pattern

def layout_transform_nchw2nhwc_pattern():
pattern = is_op("layout_transform")(is_var(), wildcard(), wildcard()).has_attr(
{"src_layout": "NCHW", "dst_layout": "NHWC"}
@@ -596,6 +610,13 @@ def check_globalavgpool2d(extract):
call = call.args[0]
return globalavgpool2d_nhwc2nhwc(call)

def check_globalmaxpool2d(extract):
"""Check globalmaxpool2d pattern is supported by Mrvl."""
call = extract
while call.op.name != "nn.global_max_pool2d":
call = call.args[0]
return globalmaxpool2d_nhwc2nhwc(call)

def check_reshape(extract):
call = extract
while call.op.name != "reshape":
@@ -608,6 +629,12 @@ def check_batch_flatten(extract):
call = call.args[0]
return batch_flatten_mrvl(call)

def check_squeeze(extract):
call = extract
while call.op.name != "squeeze":
call = call.args[0]
return squeeze_mrvl(call)

def check_layout_transform_nchw2nhwc(extract):
call = extract
while call.op.name != "layout_transform":
@@ -634,6 +661,7 @@ def check_concat(extract):
("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d),
("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d),
("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(), check_globalavgpool2d),
("mrvl.globalmaxpool2d_nhwc2nhwc", globalmaxpool2d_pattern(), check_globalmaxpool2d),
("mrvl.sum", sum_pattern(), check_sum),
("mrvl.concat", concat_pattern(), check_concat),
(
@@ -643,6 +671,7 @@ ),
),
("mrvl.reshape", reshape_pattern(), check_reshape),
("mrvl.batch_flatten", batch_flatten_pattern(), check_batch_flatten),
("mrvl.squeeze", squeeze_pattern(), check_squeeze),
]
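
For context, a hedged sketch (not part of the diff) of how a table like the one above is consumed: entries are `(name, pattern, predicate)` triples fed to `MergeComposite`, which wraps each matched region in a composite function named after the pattern. This assumes the table is registered under the `"mrvl"` key, the usual BYOC convention.

```python
# Hedged sketch: fuse supported regions into "mrvl.*" composite functions.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.op.contrib.register import get_pattern_table

x = relay.var("x", shape=(1, 7, 7, 32), dtype="float32")
mod = tvm.IRModule.from_expr(
    relay.Function([x], relay.nn.global_max_pool2d(x, layout="NHWC"))
)

mrvl_patterns = get_pattern_table("mrvl")  # (name, pattern, predicate) triples
mod = transform.InferType()(mod)           # predicates read checked_type
mod = transform.MergeComposite(mrvl_patterns)(mod)
```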


@@ -813,6 +842,21 @@ def globalavgpool2d_nhwc2nhwc(expr):
return True


# register a helper function to indicate that the given operator can be supported by Mrvl.
@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.mrvl")
def globalmaxpool2d_nhwc2nhwc(expr):
"""Check if the external Mrvl codegen for globalmaxpool2d_nhwc2nhwc should be used."""
attrs, args = expr.attrs, expr.args
if attrs.layout != "NHWC":
return False
data_type = args[0].checked_type
if len(data_type.shape) != 4:
return False
if data_type.dtype not in ["float32"]:
return False
return True
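
A quick hedged illustration of the predicate's contract (not part of the patch): `checked_type` is only populated after type inference, so a bare call must be run through `InferType` first.

```python
# Hedged sketch: the predicate above accepts only 4-D float32 NHWC inputs.
import tvm
from tvm import relay
from tvm.relay.op.contrib.mrvl import globalmaxpool2d_nhwc2nhwc

x = relay.var("x", shape=(1, 7, 7, 32), dtype="float32")
call = relay.nn.global_max_pool2d(x, layout="NHWC")
mod = tvm.IRModule.from_expr(relay.Function([x], call))
mod = relay.transform.InferType()(mod)  # populates args[0].checked_type

assert globalmaxpool2d_nhwc2nhwc(mod["main"].body)  # NHWC, 4-D, float32 -> True
```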


@tvm.ir.register_op_attr("reshape", "target.mrvl")
def reshape_mrvl(expr):
"""Check if the external Mrvl codegen for reshape should be used."""
@@ -846,6 +890,14 @@ def batch_flatten_mrvl(expr):
return True


@tvm.ir.register_op_attr("squeeze", "target.mrvl")
def squeeze_mrvl(expr):
"""Check if the external Mrvl codegen for squeeze should be used."""
if expr.op.name != "squeeze":
return False
return True


# register a helper function to indicate that the given operator can be supported by Mrvl.
@tvm.ir.register_op_attr("layout_transform", "target.mrvl")
def layout_transform_nchw2nhwc(expr):
102 changes: 102 additions & 0 deletions src/relay/backend/contrib/mrvl/codegen.cc
@@ -225,6 +225,13 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
const CallNode* batch_flatten = nullptr;
};

/*!
* \brief The operator that forms a composite squeeze node.
*/
struct CompositeSqueezeNode {
const CallNode* squeeze = nullptr;
};

/*!
* \brief A series of operators that form a composite
* fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no.
@@ -278,6 +285,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
json_kernel_node = CreateCompositeMrvlAvgpool2DLayer(cn);
} else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") {
json_kernel_node = CreateCompositeMrvlGlobalAvgpool2DLayer(cn);
} else if (name == "mrvl.globalmaxpool2d_nhwc2nhwc") {
json_kernel_node = CreateCompositeMrvlGlobalMaxpool2DLayer(cn);
} else if (name == "mrvl.sum") {
json_kernel_node = CreateCompositeMrvlSumLayer(cn);
} else if (name == "mrvl.concat") {
@@ -286,6 +295,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
json_kernel_node = CreateMrvlReshapeLayer(cn);
} else if (name == "mrvl.batch_flatten") {
json_kernel_node = CreateMrvlBatchFlattenLayer(cn);
} else if (name == "mrvl.squeeze") {
json_kernel_node = CreateMrvlSqueezeLayer(cn);
} else {
LOG(FATAL) << "Unrecognized Mrvl pattern: " << name;
}
@@ -511,6 +522,22 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
return nodes;
}

/*!
* \brief Extract squeeze nodes from a composite function.
* \param call The call node of the composite function.
* \return Extracted composite squeeze nodes.
*/
CompositeSqueezeNode UnpackCompositeSqueeze(const CallNode* call) {
CompositeSqueezeNode nodes{};
const auto* fn = call->op.as<FunctionNode>();
ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
const auto* current_call = fn->body.as<CallNode>();
ICHECK(backend::IsOp(current_call, "squeeze"))
<< "Marvell-Compiler-ERROR-Internal::squeeze missing.";
nodes.squeeze = current_call;
return nodes;
}
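
For context, a hedged Python sketch (not part of the patch) of the structure this helper walks: after MergeComposite, the serializer sees a call whose op is a FunctionNode whose body is the bare squeeze call.

```python
# Hedged sketch: produce the composite-function shape UnpackCompositeSqueeze expects.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.dataflow_pattern import is_op, wildcard

x = relay.var("x", shape=(1, 1, 16, 112, 112), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.squeeze(x, axis=[0, 1])))
mod = transform.MergeComposite([("mrvl.squeeze", is_op("squeeze")(wildcard()))])(mod)
print(mod)  # main now calls a Composite="mrvl.squeeze" function whose body is squeeze
```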

/*!
* \brief Extract maxpool nodes from a composite function.
*
@@ -533,6 +560,11 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
<< "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
ICHECK(backend::IsOp(current_call, "nn.avg_pool2d"))
<< "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
} else if (mrvlLayerName == "GlobalMaxpool2D") {
ICHECK(mrvlLayerName == "GlobalMaxpool2D")
<< "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
ICHECK(backend::IsOp(current_call, "nn.global_max_pool2d"))
<< "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
} else {
ICHECK(mrvlLayerName == "GlobalAvgpool2D")
<< "Marvell-Compiler-ERROR-Internal::nn.global_avg_pool2d Op missing.";
@@ -1115,6 +1147,34 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
return json_node;
}

/*!
* \brief Create a JSON representation of a composite Squeeze.
*
* \param cn The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateMrvlSqueezeLayer(const CallNode* cn) {
CompositeSqueezeNode nodes = UnpackCompositeSqueeze(cn);
std::vector<JSONGraphNodeEntry> inputs;
std::string name = "squeeze";
inputs.push_back(VisitExpr(cn->args[0])[0]);
std::vector<int64_t> layout_vec;
GetInputTensorShapeViaArgN(nodes.squeeze, &layout_vec);
std::string data_layout;
if (layout_vec.size() == 4) {
data_layout = "NHWC";
} else {
data_layout = "NC";
}
layout_vec.clear();
std::string out_layout = "NC";
auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
"" /* no kernel_layout */, out_layout);
return json_node;
}
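
The rank test above decides the reported input layout (a worked note, not from the patch): a 4-D producer is labeled NHWC and anything else NC, while the squeezed output is always reported as NC. For the shapes used by the new test_squeeze graph:

```python
# Hedged sketch of the shape flow in the new test_squeeze graph.
import numpy as np

conv_out = np.zeros((1, 16, 112, 112), dtype="float32")  # conv2d output
reshaped = conv_out.reshape((1, 1, 16, 112, 112))        # relay.reshape newshape
squeezed = np.squeeze(reshaped, axis=(0, 1))             # relay.squeeze axis=[0, 1]
print(squeezed.shape)  # (16, 112, 112); 5-D squeeze input -> "NC" branch above
```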

/*!
* \brief Create a JSON representation of a composite concat.
*
@@ -1304,6 +1364,48 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
return json_node;
}

/*!
* \brief Create a JSON representation of a composite globalmaxpool2d operator.
*
* \param cn The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateCompositeMrvlGlobalMaxpool2DLayer(const CallNode* cn) {
std::string mrvlLayerName = "GlobalMaxpool2D";
std::string name = "nn.globalmaxpool2d_nhwc2nhwc";
CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);

const auto* globalmaxpool_attr = nodes.pool->attrs.as<GlobalPool2DAttrs>();
ICHECK(globalmaxpool_attr)
<< "Marvell-Compiler-ERROR-Internal::Downcast to GlobalPool2DAttrs failed.";
ICHECK(globalmaxpool_attr->layout == "NHWC")
<< "Marvell-Compiler-ERROR-Internal::"
<< "Layout must be NHWC, has the module been pre-processed correctly?";

std::string data_layout = globalmaxpool_attr->layout;
std::string out_layout = globalmaxpool_attr->layout;
std::vector<JSONGraphNodeEntry> inputs;
std::vector<int64_t> kernel_layout_vec;
std::vector<int64_t> data_layout_vec;
GetInputTensorShapeViaArgN(cn, &data_layout_vec);
ICHECK(data_layout_vec.size() == 4);
kernel_layout_vec.push_back(data_layout_vec[1]);
kernel_layout_vec.push_back(data_layout_vec[2]);
inputs.push_back(VisitExpr(cn->args[0])[0]);

// op_type_ is "kernel"
auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
SetCallNodeAttribute(json_node, nodes.pool);
JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);

SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
out_layout);
return json_node;
}
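
The `kernel_layout_vec` built above records the input's H and W extents because a global max pool has no explicit kernel; its implicit window spans the whole spatial plane. A hedged numpy sketch (not part of the patch):

```python
# Hedged sketch: global max pooling over NHWC is a max over the H and W axes,
# which is why the implicit "kernel" recorded above is (H, W).
import numpy as np

x = np.random.rand(1, 7, 7, 32).astype("float32")  # NHWC
y = x.max(axis=(1, 2), keepdims=True)              # -> (1, 1, 1, 32)
print(y.shape)
```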

/*!
* \brief Create a JSON representation of an OpNode layer.
*
108 changes: 108 additions & 0 deletions tests/python/contrib/test_mrvl/test_mrvl.py
@@ -181,7 +181,115 @@ def get_graph():
run_and_verify_func(get_graph())


@requires_mrvl
def test_maxpool2d():
"""Test maxpool2d operator for "mrvl" targets"""

def get_graph():
x = relay.var("x", shape=(1, 3, 224, 224))
arr = np.random.rand(16, 3, 3, 3).astype("float32")
w = relay.const(arr)
y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
y = relay.nn.max_pool2d(y)
func = relay.Function([x], y)
mod = tvm.IRModule()
mod["main"] = func
option_dict = {"num_tiles": 1}
verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.maxpool2d_nhwc2nhwc")
return func, {"x": (1, 3, 224, 224)}, [], option_dict

run_and_verify_func(get_graph())


@requires_mrvl
def test_avgpool2d():
"""Test avgpool2d operator for "mrvl" targets"""

def get_graph():
x = relay.var("x", shape=(1, 3, 224, 224))
arr = np.random.rand(16, 3, 3, 3).astype("float32")
w = relay.const(arr)
y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
y = relay.nn.avg_pool2d(y)
func = relay.Function([x], y)
mod = tvm.IRModule()
mod["main"] = func
option_dict = {"num_tiles": 1}
verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.avgpool2d_nhwc2nhwc")
return func, {"x": (1, 3, 224, 224)}, [], option_dict

run_and_verify_func(get_graph())


@requires_mrvl
def test_globalavgpool2d():
"""Test globalavgpool2d operator for "mrvl" targets"""

def get_graph():
x = relay.var("x", shape=(1, 3, 224, 224))
arr = np.random.rand(16, 3, 3, 3).astype("float32")
w = relay.const(arr)
y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
y = relay.nn.global_avg_pool2d(y)
func = relay.Function([x], y)
mod = tvm.IRModule()
mod["main"] = func
option_dict = {"num_tiles": 1}
verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.globalavgpool2d_nhwc2nhwc")
return func, {"x": (1, 3, 224, 224)}, [], option_dict

run_and_verify_func(get_graph())


@requires_mrvl
def test_globalmaxpool2d():
"""Test globalmaxpool2d operator for "mrvl" targets"""

def get_graph():
x = relay.var("x", shape=(1, 3, 224, 224))
arr = np.random.rand(16, 3, 3, 3).astype("float32")
w = relay.const(arr)
y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
y = relay.nn.global_max_pool2d(y)
func = relay.Function([x], y)
params = {}
params["w"] = arr
mod = tvm.IRModule()
mod["main"] = func
option_dict = {"num_tiles": 1}
verify_codegen(mod, params=params, tvm_ops=2, contains="mrvl.globalmaxpool2d_nhwc2nhwc")
return func, {"x": (1, 3, 224, 224), "w": (16, 3, 3, 3)}, ["w"], option_dict

run_and_verify_func(get_graph())


@requires_mrvl
def test_squeeze():
"""Test squeeze operator for "mrvl" targets"""

def get_graph():
x = relay.var("x", shape=(1, 3, 224, 224))
arr = np.random.rand(16, 3, 3, 3).astype("float32")
w = relay.const(arr)
y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
y = relay.reshape(y, newshape=(1, 1, 16, 112, 112))
y = relay.squeeze(y, axis=[0, 1])
func = relay.Function([x], y)
mod = tvm.IRModule()
mod["main"] = func
option_dict = {"num_tiles": 1}
verify_codegen(mod, params={}, tvm_ops=3, contains="mrvl.squeeze")
return func, {"x": (1, 3, 224, 224)}, [], option_dict

run_and_verify_func(get_graph())


if __name__ == "__main__":
test_mrvl_fuse()
test_conv2d()
test_dense()
test_maxpool2d()
test_avgpool2d()
test_globalavgpool2d()
test_globalmaxpool2d()
test_squeeze()