[BYOC][DNNL] Enable layer normalization in DNNL byoc. (apache#11508)
* Enable layer normalization in DNNL byoc.

* Added a unit test for layer norm and made the code compatible with the newly introduced TensorRequisite (PR-11345)

* Fix lint issue

* Fix clang format issue
billishyahao authored Jun 8, 2022
1 parent 96a513c commit 9817338
Showing 3 changed files with 137 additions and 1 deletion.
70 changes: 69 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table

logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
_register_external_op_helper("nn.layer_norm")
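Note on the registration above: _register_external_op_helper marks an operator as offloadable to the "dnnl" external codegen by attaching a "target.dnnl" attribute to it. The helper's body is not part of this hunk; the following is only a minimal sketch of how such a helper is typically written in TVM's BYOC flow, shown for orientation and not taken from this commit:

def _register_external_op_helper(op_name, supported=True):
    """Sketch: mark op_name as supported by the DNNL external codegen."""

    @tvm.ir.register_op_attr(op_name, "target.dnnl")
    def _func_wrapper(expr):
        # Returning True lets AnnotateTarget("dnnl") offload calls to this op.
        return supported

    return _func_wrapper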


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -455,6 +456,7 @@ def visit_call(self, call):
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
"nn.layer_norm",
]
)
if isinstance(call.op, tvm.tir.op.Op):
@@ -526,3 +528,69 @@ def visit_call(self, call):
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod


class LayerNormRewrite(DFPatternCallback):
    """
    A callback to rewrite the following operators into a single layer normalization operator.

    Pattern #1:
    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
            /* ty=Tensor[(1, 3136, 64), float32] */;
    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
            /* ty=Tensor[(1, 3136, 64), float32] */;

    Pattern #2:
    1   %0 = mean(%input, axis=[-1], keepdims=True);
    2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
    3   %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
    4   %3 = subtract(%input, %0);
    5   %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
    6   %5 = divide(%3, %4);
    7   %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
            /* ty=Tensor[(1, 49, 64), float32] */;
    8   %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
            /* ty=Tensor[(1, 49, 64), float32] */
    """

    def __init__(self):
        super(LayerNormRewrite, self).__init__()
        self.data = wildcard()
        self.gamma = wildcard()
        self.beta = wildcard()
        mu = is_op("mean")(self.data)
        diff = is_op("subtract")(self.data, mu)
        cdiff = diff | is_op("cast")(diff)
        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
        p1 = is_op("power")(cdiff, const_two)
        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
        eps = is_expr(relay.const(1e-5))
        added_eps = is_op("add")(mp1, eps)
        deno = is_op("sqrt")(added_eps)
        div_out = is_op("divide")(diff, deno)
        weighted = is_op("multiply")(div_out, self.gamma)
        added_bias = is_op("add")(weighted, self.beta)
        self.pattern = added_bias

    def callback(self, pre, post, node_map):
        data = node_map[self.data][0]
        gamma = node_map[self.gamma][0]
        beta = node_map[self.beta][0]
        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)


def rewrite_layer_norm(mod):
    """Rewrite the input graph, replacing the decomposed operator sequences above with TVM's
    native layer normalization operator so that they can be offloaded to the DNNL layer
    normalization implementation via BYOC.
    """
    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
    return mod
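The rewrite above can be exercised on its own. Below is a hypothetical usage sketch (not part of this diff) that builds the decomposed form of pattern #1 by hand and checks that rewrite_layer_norm collapses it into a single nn.layer_norm call; the shape and constants mirror the docstring examples:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

x = relay.var("input", shape=(1, 49, 64), dtype="float32")
gamma = relay.const(np.ones(64, dtype="float32"))
beta = relay.const(np.zeros(64, dtype="float32"))

# Decomposed layer norm: mean / subtract / power / mean / add eps / sqrt / divide / scale / shift.
mu = relay.mean(x, axis=[-1], keepdims=True)
diff = relay.subtract(x, mu)
sq = relay.power(diff, relay.const(2.0))
var = relay.mean(sq, axis=[-1], keepdims=True)
std = relay.sqrt(relay.add(var, relay.const(1e-5)))
out = relay.add(relay.multiply(relay.divide(diff, std), gamma), beta)

mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod = relay.transform.InferType()(mod)
mod = dnnl.rewrite_layer_norm(mod)
print(mod)  # main is expected to now contain a single nn.layer_norm call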
47 changes: 47 additions & 0 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
          Binary(nid, dnnl::algorithm::binary_add);
        } else if ("multiply" == op_name) {
          Binary(nid, dnnl::algorithm::binary_mul);
        } else if ("nn.layer_norm" == op_name) {
          LayerNorm(nid);
        } else {
          LOG(FATAL) << "Unsupported op: " << op_name;
        }
@@ -449,6 +451,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_VARIANCE, var_tr}});
}

  void LayerNorm(const size_t& nid) {
    auto node = nodes_[nid];

    auto src_tr = GetInput(nid, 0);
    auto gamma_tr = GetInput(nid, 1);
    auto beta_tr = GetInput(nid, 2);
    auto dst_tr = GetOutput(nid, 0);

    auto axis = GetNodeAttr<int>(node, "axis");
    auto epsilon = GetNodeAttr<float>(node, "epsilon");
    auto center = GetNodeAttr<bool>(node, "center");
    auto scale = GetNodeAttr<bool>(node, "scale");

    ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";

    // LN description.
    auto lnorm_desc = dnnl::layer_normalization_forward::desc(
        dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
        dnnl::normalization_flags::use_scale_shift);

    auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);

    // Concatenate scale and shift tensors
    auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
    auto sc_sh_dims = scale_shift_tr.dims();

    ICHECK(sc_sh_dims.size() == 2);
    ICHECK(sc_sh_dims[0] == 2);
    sc_sh_dims[0] /= 2;
    auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
    auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();

    auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) {
      dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc());
      Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
    };

    register_copy(gamma_tr, scale_tr);
    register_copy(beta_tr, shift_tr);

    Submit(dnnl::layer_normalization_forward(lnorm_prim_desc),
           {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
  }

void Pooling(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];

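A note on the scale/shift handling in LayerNorm above: with dnnl::normalization_flags::use_scale_shift, the primitive consumes gamma and beta as one packed (2, C) weights tensor, which is why the code crops the scale row at offset {0, 0} and the shift row at offset {1, 0} and registers a reorder from each Relay input into its row. A small NumPy illustration of that layout (illustration only, not part of this diff):

import numpy as np

C = 64
gamma = np.ones(C, dtype="float32")   # Relay input 1
beta = np.zeros(C, dtype="float32")   # Relay input 2

scale_shift = np.empty((2, C), dtype="float32")
scale_shift[0] = gamma  # row cropped at offset {0, 0}: scale
scale_shift[1] = beta   # row cropped at offset {1, 0}: shift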
21 changes: 21 additions & 0 deletions tests/python/contrib/test_dnnl.py
Expand Up @@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
with tvm.transform.PassContext(opt_level=3):
mod = alter_layout_seq(mod)

mod = dnnl.rewrite_layer_norm(mod)

byoc_seq = tvm.transform.Sequential(
[
transform.MergeComposite(dnnl.pattern_table()),
@@ -454,6 +456,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
return relay.nn.relu(conv2d_bias_bn), dic, param_lst


def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
    dic = {"input": x_shape}
    param_lst = []
    input = relay.var("input", shape=x_shape)
    beta = relay.const(np.zeros(x_shape[2]).astype(dtype))
    gamma = relay.const(np.ones(x_shape[2]).astype(dtype))
    out = relay.nn.layer_norm(input, gamma=gamma, beta=beta)
    return out, dic, param_lst


def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1044,14 @@ def get_graph():
run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)


def test_layer_norm(run_module, dtype="float32"):
    x_shape = (1, 49, 64)

    ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
    ln = tvm.IRModule.from_expr(ln)
    config = ln, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
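For a by-hand check of what this test asserts, the following hypothetical NumPy reference (not part of this diff) computes layer normalization for the same configuration, assuming the relay.nn.layer_norm defaults of axis=-1, epsilon=1e-5, and biased variance:

import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased (population) variance
    return (x - mu) / np.sqrt(var + eps) * gamma + beta

x = np.random.uniform(size=(1, 49, 64)).astype("float32")
ref = layer_norm_ref(x, np.ones(64, "float32"), np.zeros(64, "float32"))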


if __name__ == "__main__":
tvm.testing.main()
