
[BYOC-DNNL] enable conv3d->bn folding #10837

Merged · 5 commits · Apr 11, 2022
Changes from 1 commit
support conv3d bn folding
crazydemo committed Mar 18, 2022
commit 283a95633c3e8805944c9d55f22d9ab35435425c
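
For context on what this change folds: at inference time, batch norm reduces to a per-output-channel scale and shift, which FoldScaleAxis can push into the convolution weights. A minimal standalone sketch of that arithmetic for an OIDHW kernel (the helper, its names, and shapes are illustrative only, not part of this diff):

#include <cmath>
#include <vector>

// Fold inference-time batch norm (gamma, beta, mean, var) into a conv3d
// weight laid out as OIDHW plus its bias. Afterwards, conv3d(x, W) + b
// computes what BN(conv3d(x, W_orig) + b_orig) computed before.
void FoldBNIntoConv3D(std::vector<float>& W, std::vector<float>& b,
                      const std::vector<float>& gamma, const std::vector<float>& beta,
                      const std::vector<float>& mean, const std::vector<float>& var,
                      int O, int IDHW, float eps = 1e-5f) {
  for (int o = 0; o < O; ++o) {
    float s = gamma[o] / std::sqrt(var[o] + eps);   // per-output-channel scale
    for (int k = 0; k < IDHW; ++k) W[o * IDHW + k] *= s;
    b[o] = beta[o] + (b[o] - mean[o]) * s;          // per-output-channel shift
  }
}
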
196 changes: 196 additions & 0 deletions src/relay/transforms/fold_scale_axis.cc
@@ -589,6 +589,118 @@ RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", C
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);


// Consumer operators
// Conv3D sends out its requirement for axis folding.
Array<Message> Conv3DForwardPrep(const Call& call, const Message& out_message) {
// TODO(tvm-team) support general data layout
// by transforming weight
const auto* param = call->attrs.as<Conv3DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));

ICHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
// For now, we only support the simple pattern (no folded weight/data).
// More general layouts can be supported under the current framework
// by using a unified layout transformation; only the Prep and Mutate
// functions would need to change.
//
// Only handle depthwise or full conv3d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv3d = IsDepthwiseConv3D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv3d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return {Message(arr, false), none};
}
}
return {none, none};
}

// Conv3D consumes the scale axis during transformation.
Expr Conv3DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
const Message& message) {
// If the data does not carry a scale, take the normal transform path.
const auto* sdata = new_args[0].as<ScaledExprNode>();
const auto* sweight = new_args[1].as<ScaledExprNode>();
if (sdata == nullptr) return Expr();
if (sweight != nullptr) return Expr();
const auto* param = ref_call->attrs.as<Conv3DAttrs>();
ICHECK(param != nullptr);
Layout data_layout(param->data_layout);
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
ICHECK_GE(c_big_axis, 0);
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
ICHECK(is_simple || is_blocking);
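// e.g. kernel_layout "OIDHW" is the simple case, while a blocked layout
// such as "IODHW8i8o" carries the inner 8-element i/o sub-axes checked above.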

// Check that the conv is depthwise or full conv3d.
bool is_depthwise_conv3d = IsDepthwiseConv3D(ref_call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv3d);

Expr weight = new_args[1];

// Match the input-channel axis
// for grouped conv with simple layout OIDHW.
if (param->groups > 1 && !is_depthwise_conv3d) {
if (is_simple) {
const Array<PrimExpr>& weight_shape_ = weight->type_as<TensorTypeNode>()->shape;
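// weight_shape_[1] holds C_in / groups for a grouped conv, so scale it back
// up by groups to recover the full input-channel extent the data scale spans.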
auto IC = weight_shape_[1] * param->groups;
Array<PrimExpr> weight_shape = {weight_shape_[0], IC};
weight_shape.insert(weight_shape.end(), weight_shape_.begin() + 2, weight_shape_.end());
Expr scale = ReshapeToMatchAxis(sdata->scale, weight_shape, {big_ki_axis});
weight = Multiply(weight, scale);
}
if (!weight.defined()) return Expr();
// for depthwise conv
} else if (is_depthwise_conv3d) {
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis}));
if (!weight.defined()) return Expr();
}

} else {
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ki_axis, small_ki_axis}));
if (!weight.defined()) return Expr();
}
}
// return transformed conv3d
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}

RELAY_REGISTER_OP("nn.conv3d").set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv3DForwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv3DForwardRewrite);


// Dense send out requirement of axis folding.
Array<Message> DenseForwardPrep(const Call& call, const Message& out_message) {
return {Message({1}, false), NullValue<Message>()};
@@ -1018,6 +1130,90 @@ RELAY_REGISTER_OP("nn.conv2d")
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);


// Consumer operators
// Conv3D sends out its requirement for axis folding.
Message Conv3DBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* param = call->attrs.as<Conv3DAttrs>();
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));

ICHECK_GE(c_big_axis, 0);
// For now, we only support the simple pattern (no folded weight/data).
// More general layouts can be supported under the current framework
// by using a unified layout transformation; only the Prep and Mutate
// functions would need to change.
//
// Only handle depthwise or full conv3d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv3d = IsDepthwiseConv3D(call, param, kernel_layout);
if (param->groups == 1 || is_depthwise_conv3d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return Message(arr, false);
}
}
return NullValue<Message>();
}
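// Note: unlike Conv3DForwardPrep, which locates the channel axis in the
// data layout, the backward prep looks for the channel axis in out_layout,
// since the scale to be folded arrives from a consumer of the conv output.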

// Conv3D consumes the scale axis during transformation.
Expr Conv3DBackwardTransform(const Call& call, const Message& message, const Expr& scale,
const BackwardTransformer& transformer) {
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
const auto* param = call->attrs.as<Conv3DAttrs>();
ICHECK(param != nullptr);
Layout kernel_layout(param->kernel_layout);
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
ICHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check that the conv is depthwise or full conv3d.
bool is_depthwise_conv3d = IsDepthwiseConv3D(call, param, kernel_layout);
ICHECK(param->groups == 1 || is_depthwise_conv3d);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
ICHECK(is_simple || is_blocking);

Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for depthwise.
Expr wscale;
if (is_simple) {
wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
} else {
wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis});
if (!wscale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
}
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv3DBackwardPrep);

RELAY_REGISTER_OP("nn.conv3d")
.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv3DBackwardTransform);


Message BiasAddBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const BiasAddAttrs* attrs = call->attrs.as<BiasAddAttrs>();
ICHECK(attrs);
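
Stepping back from the diff for a moment: Conv3DBackwardTransform above relies on the identity conv3d(x, W) * s == conv3d(x, W * s) when s is constant along the output-channel axis. A tiny self-contained check of that identity for a pointwise (1x1x1) kernel — entirely illustrative, with values chosen as dyadic rationals so the float comparison is exact:

#include <cassert>
#include <vector>

// For a pointwise conv3d, scaling each output channel of the result equals
// scaling the corresponding kernel slice, which is exactly what
// Conv3DBackwardTransform exploits to fold a trailing per-channel multiply
// (e.g. from batch norm) into the weights.
int main() {
  const int O = 2, I = 3;
  std::vector<float> x = {1.f, 2.f, 3.f};                     // one voxel, I channels
  std::vector<float> W = {0.5f, -1.f, 2.f, 1.f, 0.f, -0.5f};  // O x I kernel
  std::vector<float> s = {3.f, 0.25f};                        // per-O scale
  for (int o = 0; o < O; ++o) {
    float y = 0.f, y_folded = 0.f;
    for (int i = 0; i < I; ++i) {
      y += W[o * I + i] * x[i];                  // conv, then scale the output
      y_folded += (W[o * I + i] * s[o]) * x[i];  // scale the kernel, then conv
    }
    assert(y * s[o] == y_folded);  // exact here: all values are dyadic
  }
  return 0;
}
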
15 changes: 15 additions & 0 deletions src/relay/transforms/pattern_utils.h
@@ -197,6 +197,21 @@ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param,
return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}

/*!
* \brief Check if the call is depthwise conv3d.
*
* \param call The conv3d call.
 * \param param The conv3d attributes.
 * \param kernel_layout The kernel layout of the conv3d.
 * \return Whether it is depthwise_conv3d.
*/
inline bool IsDepthwiseConv3D(const Call& call, const Conv3DAttrs* param,
const Layout& kernel_layout) {
static const Layout kOIDHW("OIDHW");
const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIDHW);
auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
}

/*!
* \brief Get super-dimension of output channels of conv2d
* \param call The conv2d call.
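
A usage note on IsDepthwiseConv3D above (values hypothetical): a depthwise conv3d over 16 channels has groups == 16 and an OIDHW weight shape of (16, 1, kD, kH, kW), so both checks pass; a grouped-but-not-depthwise conv, say groups == 2 with weight shape (16, 8, 3, 3, 3), fails both the wshape[0] == groups and the wshape[1] == 1 checks.
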
5 changes: 4 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -157,6 +157,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{"IODHW8i8o", tag::any},
{"ODHWI8o", tag::Odhwi8o},
{"ODHWI16o", tag::Odhwi16o},
{"ODHWI32o", tag::Odhwi32o},
{"ODHWI48o", tag::Odhwi48o},
{"ODHWI64o", tag::Odhwi64o},
};

bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
@@ -382,7 +385,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);

-    // Covn2d description.
+    // Conv description.
auto conv_desc =
has_bias ? dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct,