[BYOC-DNNL] enable conv3d->bn folding #10837

Merged (5 commits, Apr 11, 2022)
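At the Relay level, this change lets FoldScaleAxis fold the per-channel scale that a batch_norm reduces to (after SimplifyInference) into the weights of a conv3d, the same way it already did for conv2d. A minimal sketch of that usage; shapes, names, and the standalone pass invocation are illustrative and not taken from this PR:

```python
import numpy as np
import tvm
from tvm import relay

# Illustrative NCDHW workload: conv3d followed by a per-output-channel scale,
# which is the form a batch_norm takes after SimplifyInference.
data = relay.var("data", shape=(1, 8, 16, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(16, 8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv3d(data, weight, channels=16, kernel_size=(3, 3, 3),
                       padding=(1, 1, 1), data_layout="NCDHW",
                       kernel_layout="OIDHW")
scale = relay.const(np.random.uniform(0.5, 2.0, (16, 1, 1, 1)).astype("float32"))
out = relay.multiply(conv, scale)

mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)
# With this PR the scale is folded backward into the conv3d weights, so the
# trailing multiply disappears and the DNNL BYOC path sees a plain conv3d.
print(relay.transform.BackwardFoldScaleAxis()(mod))
```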
108 changes: 79 additions & 29 deletions src/relay/transforms/fold_scale_axis.cc
@@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../backend/utils.h"
#include "../op/tensor/transform.h"
#include "pass_utils.h"
#include "pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);

// Consumer operators
// Conv2D send out requirement of axis folding.
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// Conv send out requirement of axis folding.
template <typename ATTRS>
Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
// TODO(tvm-team) support general data layout
// by transforming weight
const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv2d) {
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
}

// Conv2D consumes the scale axis during transformation.
Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
template <typename ATTRS>
Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
const Message& message) {
// if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
ICHECK(is_simple || is_blocking);

// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv2d);
bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv);

Expr weight = new_args[1];

// match the ic_axis
if (is_depthwise_conv2d) {
if (is_depthwise_conv) {
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
if (!weight.defined()) return Expr();
}
}
// return transformed conv2d
// return transformed conv
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvForwardPrep(call, param, out_message);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvForwardPrep(call, param, out_message);
}

Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
return ConvForwardRewrite(ref_call, param, new_args, message);
}
const auto* param = ref_call->attrs.as<Conv3DAttrs>();
Review comment (Member):
ICHECK(param)
Add that to other places as well.

Reply (Contributor Author):
Got it.

return ConvForwardRewrite(ref_call, param, new_args, message);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);
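With the forward registrations above, a per-input-channel scale sitting in front of a conv3d can also be pushed into the kernel, mirroring the existing conv2d behaviour. A rough sketch under the same illustrative assumptions as before (not code from this PR):

```python
import numpy as np
import tvm
from tvm import relay

# Illustrative: scale the NCDHW input per channel, then run conv3d.
data = relay.var("data", shape=(1, 8, 16, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(16, 8, 3, 3, 3), dtype="float32")
in_scale = relay.const(np.random.uniform(0.5, 2.0, (8, 1, 1, 1)).astype("float32"))
conv = relay.nn.conv3d(relay.multiply(data, in_scale), weight,
                       channels=16, kernel_size=(3, 3, 3), padding=(1, 1, 1),
                       data_layout="NCDHW", kernel_layout="OIDHW")

mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
mod = relay.transform.InferType()(mod)
# The forward pass now folds in_scale into the conv3d kernel along its
# input-channel axis instead of leaving a multiply on the data.
print(relay.transform.ForwardFoldScaleAxis()(mod))
```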

// Dense send out requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D send out requirement of axis folding.
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* param = call->attrs.as<Conv2DAttrs>();
// Conv send out requirement of axis folding.
template <typename ATTRS>
Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
// only handle depthwise or full conv2d.
// only handle depthwise or full conv.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv2d) {
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
return NullValue<Message>();
}

// Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
// Conv consumes the scale axis during transformation.
template <typename ATTRS>
Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
const Expr& scale, const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv2d);
// Check it must be depthwise or full conv.
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
ICHECK(is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvBackwardPrep(call, param, in_messages);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvBackwardPrep(call, param, in_messages);
}

Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvBackwardTransform(call, param, message, scale, transformer);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvBackwardTransform(call, param, message, scale, transformer);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
18 changes: 10 additions & 8 deletions src/relay/transforms/pattern_utils.h
@@ -44,6 +44,7 @@
#include <utility>
#include <vector>

#include "../backend/utils.h"
#include "../op/make_op.h"

namespace tvm {
@@ -183,16 +184,17 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Intege
}

/*!
* \brief Check if the call is depthwise conv2d.
* \brief Check if the call is depthwise conv3d.
*
* \param call The conv2d call.
* \param param The conv2d attributes.
* \return Whether it is depthwise_conv2d.
* \param call The conv call.
* \param param The conv attributes.
* \return Whether it is depthwise_conv3d.
*/
inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
template <typename ATTRS>
inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
static const Layout kOIXX =
backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}
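The templated check maps the call's kernel layout onto canonical OIHW (conv2d) or OIDHW (conv3d) and treats the convolution as depthwise when the mapped O dimension equals groups and I equals 1. A small Python sketch of that layout mapping, with a hypothetical layout and shape not taken from this PR:

```python
import tvm

# Hypothetical depthwise conv3d kernel declared in a DHWIO layout:
# 8 groups, one input channel per group, a 3x3x3 spatial kernel.
groups = 8
kernel_shape = [3, 3, 3, 1, 8]  # D, H, W, I, O
bl = tvm.tir.bijective_layout("DHWIO", "OIDHW")
o, i, d, h, w = [int(x) for x in bl.forward_shape(kernel_shape)]
# Same criterion as IsDepthwiseConv: depthwise when O == groups and I == 1.
print(o == groups and i == 1)  # True
```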
7 changes: 5 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{"IODHW8i8o", tag::any},
{"ODHWI8o", tag::Odhwi8o},
{"ODHWI16o", tag::Odhwi16o},
{"ODHWI32o", tag::Odhwi32o},
{"ODHWI48o", tag::Odhwi48o},
{"ODHWI64o", tag::Odhwi64o},
};

bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -342,7 +345,7 @@

if (layout_dict.find(kernel_layout) == layout_dict.end()) {
layout_dict.insert({kernel_layout, tag::any});
LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout
LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
<< ", transfer to tag::any";
}

@@ -382,7 +385,7 @@
auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);

// Covn2d description.
// Conv description.
auto conv_desc =
has_bias ? dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,