[BYOC-DNNL] enable conv3d->bn folding (apache#10837)
* support conv3d bn folding

* add test case for fold_scale_axis

* fix lint errors

* remove test cases

* unify conv2d and conv3d implementations, and add test cases.
crazydemo authored and altanh committed Apr 28, 2022
1 parent 9e6edda commit 6c404a0
Showing 4 changed files with 272 additions and 39 deletions.
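
For context, here is a minimal Relay sketch (not part of this commit; shapes, constant values, and the pass sequence are illustrative assumptions) of the conv3d -> batch_norm pattern that FoldScaleAxis can now eliminate:

```python
# Hedged sketch: conv3d followed by batch_norm, folded via the usual
# SimplifyInference -> FoldConstant -> FoldScaleAxis recipe. Shapes and
# constant values are illustrative assumptions, not the PR's test case.
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 8, 16, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(16, 8, 3, 3, 3), dtype="float32")
gamma = relay.const(np.ones(16, dtype="float32"))
beta = relay.const(np.zeros(16, dtype="float32"))
mean = relay.const(np.zeros(16, dtype="float32"))
var = relay.const(np.ones(16, dtype="float32"))

conv = relay.nn.conv3d(data, weight, kernel_size=(3, 3, 3), channels=16,
                       padding=(1, 1, 1), data_layout="NCDHW",
                       kernel_layout="OIDHW")
bn = relay.nn.batch_norm(conv, gamma, beta, mean, var)[0]

mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(bn), bn))
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.SimplifyInference(),  # lowers batch_norm to scale/shift
    relay.transform.FoldConstant(),       # collapses the scale into one constant
    relay.transform.FoldScaleAxis(),      # folds the scale into the conv3d weight
    relay.transform.FoldConstant(),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)  # the per-channel scale should now sit on the conv3d weight
```

Before this change, the scale-axis pass only knew how to fold through nn.conv2d; the diff below templates the conv handlers so nn.conv3d is dispatched to the same logic.
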
108 changes: 79 additions & 29 deletions src/relay/transforms/fold_scale_axis.cc
@@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../backend/utils.h"
#include "../op/tensor/transform.h"
#include "pass_utils.h"
#include "pattern_utils.h"
@@ -492,11 +493,11 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", MultiplyForwardRewrite);

// Consumer operators
// Conv2D send out requirement of axis folding.
Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// Conv send out requirement of axis folding.
template <typename ATTRS>
Array<Message> ConvForwardPrep(const Call& call, const ATTRS* param, const Message& out_message) {
// TODO(tvm-team) support general data layout
// by transforming weight
const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -512,8 +513,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv2d) {
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -529,14 +530,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
}

// Conv2D consumes the scale axis during transformation.
Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
template <typename ATTRS>
Expr ConvForwardRewrite(const Call& ref_call, const ATTRS* param, const Array<Expr>& new_args,
const Message& message) {
// if data do not have scale, normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
@@ -552,13 +553,13 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
ICHECK(is_simple || is_blocking);

// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv2d);
bool is_depthwise_conv = IsDepthwiseConv(ref_call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv);

Expr weight = new_args[1];

// match the ic_axis
if (is_depthwise_conv2d) {
if (is_depthwise_conv) {
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
@@ -580,14 +581,38 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
if (!weight.defined()) return Expr();
}
}
// return transformed conv2d
// return transformed conv
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
Array<Message> PreConvForwardPrep(const Call& call, const Message& out_message) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvForwardPrep(call, param, out_message);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvForwardPrep(call, param, out_message);
}

Expr PreConvForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
if (backend::IsOp(ref_call.as<CallNode>(), "nn.conv2d")) {
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
return ConvForwardRewrite(ref_call, param, new_args, message);
}
const auto* param = ref_call->attrs.as<Conv3DAttrs>();
return ConvForwardRewrite(ref_call, param, new_args, message);
}

RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", PreConvForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", PreConvForwardRewrite);

// Dense send out requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
@@ -937,9 +962,9 @@ RELAY_REGISTER_OP("multiply")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);

// Consumer operators
// Conv2D send out requirement of axis folding.
Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* param = call->attrs.as<Conv2DAttrs>();
// Conv send out requirement of axis folding.
template <typename ATTRS>
Message ConvBackwardPrep(const Call& call, const ATTRS* param, const Array<Message>& in_messages) {
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -952,10 +977,10 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
// only handle depthwise or full conv2d.
// only handle depthwise or full conv.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv2d) {
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
@@ -970,13 +995,13 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
return NullValue<Message>();
}

// Conv2D consumes the scale axis during transformation.
Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
// Conv consumes the scale axis during transformation.
template <typename ATTRS>
Expr ConvBackwardTransform(const Call& call, const ATTRS* param, const Message& message,
const Expr& scale, const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
const auto* param = call->attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -988,9 +1013,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv2d);
// Check it must be depthwise or full conv.
bool is_depthwise_conv = IsDepthwiseConv(call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
ICHECK(is_simple || is_blocking);
@@ -1012,11 +1037,36 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

Message PreConvBackwardPrep(const Call& call, const Array<Message>& in_messages) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvBackwardPrep(call, param, in_messages);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvBackwardPrep(call, param, in_messages);
}

Expr PreConvBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (backend::IsOp(call.as<CallNode>(), "nn.conv2d")) {
const auto* param = call->attrs.as<Conv2DAttrs>();
return ConvBackwardTransform(call, param, message, scale, transformer);
}
const auto* param = call->attrs.as<Conv3DAttrs>();
return ConvBackwardTransform(call, param, message, scale, transformer);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", PreConvBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", PreConvBackwardTransform);
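
And the backward case, sketched under the same assumptions: a positive per-output-channel scale applied after the conv3d is folded into the kernel's output-channel ('O') axis.

```python
# Hedged sketch of backward scale-axis folding through conv3d (assumed shapes).
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(8, 4, 3, 3, 3), dtype="float32")
out_scale = relay.const(np.full((8, 1, 1, 1), 0.5, dtype="float32"))

conv = relay.nn.conv3d(data, weight, kernel_size=(3, 3, 3), channels=8,
                       padding=(1, 1, 1), data_layout="NCDHW",
                       kernel_layout="OIDHW")
scaled = relay.multiply(conv, out_scale)  # per-channel scale on the conv output

mod = tvm.IRModule.from_expr(relay.Function([data, weight], scaled))
mod = relay.transform.InferType()(mod)
mod = relay.transform.BackwardFoldScaleAxis()(mod)
print(mod)  # the multiply should move onto `weight` along its 'O' axis
```
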

Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
18 changes: 10 additions & 8 deletions src/relay/transforms/pattern_utils.h
@@ -44,6 +44,7 @@
#include <utility>
#include <vector>

#include "../backend/utils.h"
#include "../op/make_op.h"

namespace tvm {
@@ -183,16 +184,17 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Intege
}

/*!
* \brief Check if the call is depthwise conv2d.
* \brief Check if the call is a depthwise conv (conv2d or conv3d).
*
* \param call The conv2d call.
* \param param The conv2d attributes.
* \return Whether it is depthwise_conv2d.
* \param call The conv call.
* \param param The conv attributes.
* \return Whether it is a depthwise conv.
*/
inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
template <typename ATTRS>
inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
static const Layout kOIXX =
backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}
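
As an illustration of what the generalized check is meant to accept (a sketch under assumed shapes, not code from this PR): a depthwise conv3d whose OIDHW weight satisfies wshape[0] == groups and wshape[1] == 1.

```python
# Hedged sketch: a depthwise conv3d, i.e. groups == channels and an
# OIDHW weight of shape (channels, 1, kD, kH, kW). Sizes are illustrative.
import tvm
from tvm import relay

channels = 8
data = relay.var("data", shape=(1, channels, 8, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(channels, 1, 3, 3, 3), dtype="float32")

depthwise = relay.nn.conv3d(data, weight, kernel_size=(3, 3, 3),
                            channels=channels, groups=channels,
                            padding=(1, 1, 1), data_layout="NCDHW",
                            kernel_layout="OIDHW")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], depthwise))
print(relay.transform.InferType()(mod))
```
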
7 changes: 5 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{"IODHW8i8o", tag::any},
{"ODHWI8o", tag::Odhwi8o},
{"ODHWI16o", tag::Odhwi16o},
{"ODHWI32o", tag::Odhwi32o},
{"ODHWI48o", tag::Odhwi48o},
{"ODHWI64o", tag::Odhwi64o},
};

bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -342,7 +345,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

if (layout_dict.find(kernel_layout) == layout_dict.end()) {
layout_dict.insert({kernel_layout, tag::any});
LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout
LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
<< ", transfer to tag::any";
}

@@ -382,7 +385,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);

// Covn2d description.
// Conv description.
auto conv_desc =
has_bias ? dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,
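
The newly registered tags (ODHWI32o, ODHWI48o, ODHWI64o) are blocked conv3d kernel layouts: the output-channel axis is split into fixed-size blocks, with the block lane innermost. A small NumPy sketch (block size 8 and all dimensions are illustrative assumptions) shows the reordering that a layout such as ODHWI8o implies for a weight stored as OIDHW:

```python
# Hedged sketch: rearrange an OIDHW conv3d weight into ODHWI8o blocking.
# Dimensions and the block size are illustrative assumptions.
import numpy as np

O, I, D, H, W, blk = 32, 4, 3, 3, 3, 8
w_oidhw = np.random.rand(O, I, D, H, W).astype("float32")

# Split O into (O // blk, blk) and reorder to (O//blk, D, H, W, I, blk).
w_odhwi8o = (w_oidhw.reshape(O // blk, blk, I, D, H, W)
                    .transpose(0, 3, 4, 5, 2, 1))
print(w_odhwi8o.shape)  # (4, 3, 3, 3, 4, 8)
```

The larger block sizes extend the same scheme, so these kernel layouts resolve to a DNNL format tag instead of hitting the "Unregistered kernel layout" warning above.
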