[ONNX] Add imports for BERT contrib operators #10949

Merged: 17 commits, Apr 13, 2022
231 changes: 228 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3):
return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)


def layer_norm(x, eps, gamma, beta):
"""Common function to handle layer norm"""
eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
if beta is not None:
output = _op.add(output, beta)

return output

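For readers following the math, the helper above is the standard layer normalization over the last axis. A minimal NumPy sketch of the same computation (illustrative only, not part of the diff):

```python
import numpy as np

def layer_norm_ref(x, eps, gamma, beta=None):
    # Normalize over the last axis, mirroring _op.mean_variance(x, axis=-1).
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    out = (x - u) / np.sqrt(s + eps)
    out = out * gamma
    if beta is not None:
        out = out + beta
    return out
```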

class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""

@@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params):
x = inputs[0]

# Declare consts
half = _expr.const(0.5)
one = _expr.const(1.0)
sqrt2 = _expr.const(math.sqrt(2))
const_dtype = infer_type(x).checked_type.dtype
half = _expr.const(0.5, dtype=const_dtype)
one = _expr.const(1.0, dtype=const_dtype)
sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

# Compute gelu
term1 = _op.multiply(half, x)
@@ -836,6 +853,208 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)

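The dtype-aware constants above feed the erf-based GELU formula the converter builds, and BiasGelu simply adds the bias before calling Gelu. A hedged NumPy/SciPy sketch of that math, for reference only:

```python
import numpy as np
from scipy.special import erf

def gelu_ref(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def bias_gelu_ref(x, bias):
    # BiasGelu is GELU applied to x + bias.
    return gelu_ref(x + bias)
```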

class EmbedLayerNormalization(OnnxOpConverter):
"""Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.

This layer embeds the input tokens, sums them, and applies layer normalization.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
input_ids = inputs[0]
segment_ids = inputs[1]
word_emb = inputs[2]
pos_emb = inputs[3]
segment_emb = inputs[4]
gamma = inputs[5]
beta = inputs[6]

mask = inputs[7]
pos_ids = inputs[8]

eps = attr.get("epsilon", 1e-12)

(batch_size, seq_len) = infer_shape(input_ids)

if segment_ids:
assert segment_emb

if pos_ids is None:
pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")

word_vec = _op.take(word_emb, input_ids, axis=0)
segment_vec = _op.take(segment_emb, segment_ids, axis=0)
pos_vec = _op.take(pos_emb, pos_ids, axis=0)

vec_sum = _op.add(word_vec, pos_vec)
if segment_ids:
vec_sum = _op.add(vec_sum, segment_vec)

ln = layer_norm(vec_sum, eps, gamma, beta)

mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
if mask:
# calculate number of words per sentence
mask_index = _op.sum(mask, axis=1)

return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)

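As an end-to-end illustration of what this converter emits, here is a hedged NumPy sketch that reuses the hypothetical layer_norm_ref helper sketched earlier (argument names and defaults are illustrative):

```python
import numpy as np

def embed_layer_norm_ref(input_ids, word_emb, pos_emb, gamma, beta,
                         segment_ids=None, segment_emb=None, mask=None, eps=1e-12):
    batch_size, seq_len = input_ids.shape
    pos_ids = np.tile(np.arange(seq_len), (batch_size, 1))
    vec_sum = word_emb[input_ids] + pos_emb[pos_ids]
    if segment_ids is not None:
        vec_sum = vec_sum + segment_emb[segment_ids]
    ln = layer_norm_ref(vec_sum, eps, gamma, beta)
    # mask_index counts the non-padding tokens in each sentence.
    if mask is not None:
        mask_index = mask.sum(axis=1)
    else:
        mask_index = np.zeros(batch_size, dtype="int64")
    return ln, mask_index, vec_sum
```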

class SkipLayerNormalization(OnnxOpConverter):
"""Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.

This layer sums the two input tensors (along with optional bias), and applies layer
normalization.
"""

@staticmethod
def _compute_layer_norm(x, eps, gamma, beta):
eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
if beta is not None:
output = _op.add(output, beta)

return output

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
skip = inputs[1]
gamma = inputs[2]
beta = inputs[3]
bias = inputs[4]

eps = attr.get("epsilon", 1e-12)

x = _op.add(data, skip)
if bias is not None:
x = _op.add(x, bias)

output = layer_norm(x, eps, gamma, beta)

# onnxruntime doesn't compute the other outputs, despite the documentation
placeholder = _op.const(0, dtype="float32")
Review comment (Contributor): what is this placeholder for? optional returns are mean and inverse standard variance right?

Reply (Contributor Author): that's true according to the documentation, but neither the C++ nor the CUDA onnxruntime implementation of the kernel ever calculates or returns values for these outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc


return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)

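The same computation in plain NumPy, again reusing the layer_norm_ref sketch from above (the two placeholder outputs are dropped since onnxruntime never fills them):

```python
def skip_layer_norm_ref(data, skip, gamma, beta=None, bias=None, eps=1e-12):
    # Residual add (plus optional bias), then layer normalization.
    x = data + skip
    if bias is not None:
        x = x + bias
    return layer_norm_ref(x, eps, gamma, beta)
```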

class Attention(OnnxOpConverter):
"""Operator converter for Attention from Microsoft onnxruntime contrib opset.

This is the self-attention mechanism used in transformer models.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
num_heads = attr["num_heads"]
assert (
"qkv_hidden_sizes" not in attr
), "different hidden sizes for Q, K, V are not currently supported"
assert "unidirectional" not in attr, "unidirectional attention not current supported"

# (batch, seq, in_hidden)
input_emb = inputs[0]

# (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
weight = inputs[1]

# (3 * out_hidden,)
bias = inputs[2]

# 1. (batch, 1, max_seq, max_seq)
# 2. (batch, past_seq + seq)
# 3. (batch, seq, past_seq + seq)
# 4. (batch,)
# 5. (2 * batch,)
# For now, we only support case 2.
mask_index = inputs[3]

# (2, batch, num_heads, past_seq, head_size)
past = inputs[4]

# (batch, num_heads, seq, seq)
extra_add = inputs[5]

(batch_size, seq_len, _) = infer_shape(input_emb)
(out_hidden_x3,) = infer_shape(bias)
assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
out_hidden = out_hidden_x3 // 3
assert (
out_hidden % num_heads == 0
), "output hidden size should be divisible by number of attention heads"
head_size = out_hidden // num_heads

mask_index_shape = infer_shape(mask_index)
assert (
len(mask_index_shape) == 2
and mask_index_shape[0] == batch_size
and mask_index_shape[1] == seq_len
), "currently only support (batch_size, sequence_length) mask index"

assert past is None, "past K, V state is not currently supported"
assert extra_add is None, "extra add to QxK not currently supported"

# split weight and biases and do the matmuls
w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
# need to merge batch dimensions since TVM matmul is 2D
input_emb = _op.reverse_reshape(input_emb, (-1, 0))
Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)

# massage tensors in preparation for batched matmul
def massage(tensor):
tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))

# (batch_size, num_heads, seq_len, head_size)
tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])

# (batch_size * num_heads, seq_len, head_size)
return _op.reverse_reshape(tensor, (-1, 0, 0))

Q = massage(Q)
K = massage(K)
V = massage(V)

K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
present = _op.stack([K_present, V_present], axis=0)

att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
score_dtype = infer_type(att_scores).checked_type.dtype
att_scores = _op.divide(
att_scores,
_op.const(np.sqrt(head_size), dtype=score_dtype),
)
att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))

# build the attention mask
att_mask = _op.cast(mask_index, score_dtype)
att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))

# apply the mask
att_scores = _op.add(att_scores, att_mask)
att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))

att_probs = _op.nn.softmax(att_scores, axis=-1)

output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
output = _op.transpose(output, axes=[0, 2, 1, 3])
output = _op.reshape(output, (0, 0, out_hidden))

return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)

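To make the reshapes and masking above easier to follow, here is a compact NumPy reference for the one case this converter supports, i.e. a (batch_size, sequence_length) mask and no past state. It is a sketch under those assumptions, not onnxruntime's kernel:

```python
import numpy as np

def attention_ref(input_emb, weight, bias, mask_index, num_heads):
    batch, seq, _ = input_emb.shape
    out_hidden = bias.shape[0] // 3
    head = out_hidden // num_heads

    # Project to Q, K, V: each (batch, seq, out_hidden).
    qkv = input_emb @ weight + bias
    q, k, v = np.split(qkv, 3, axis=-1)

    def heads(t):
        # (batch, num_heads, seq, head)
        return t.reshape(batch, seq, num_heads, head).transpose(0, 2, 1, 3)

    q, k, v = heads(q), heads(k), heads(v)

    # Scaled dot-product scores: (batch, num_heads, seq, seq).
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head)
    # Padding positions (mask == 0) receive a large negative additive mask.
    att_mask = (1.0 - mask_index[:, None, None, :]) * -10000.0
    scores = scores + att_mask

    # Softmax over the last axis.
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)

    out = probs @ v                      # (batch, num_heads, seq, head)
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq, out_hidden)
    present = np.stack([k, v], axis=0)   # (2, batch, num_heads, seq, head)
    return out, present
```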

class Gemm(OnnxOpConverter):
"""Operator converter for Gemm."""

Expand Down Expand Up @@ -4737,6 +4956,12 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
# TODO: We need a better way to handle different domains, in case
# of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
# are in the `com.microsoft` domain.
"EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
"SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
"Attention": Attention.get_converter(opset),
"Exp": Renamer("exp"),
"Greater": Renamer("greater"),
"GreaterOrEqual": Renamer("greater_equal"),
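Once these converters are registered, a BERT graph exported with the com.microsoft contrib ops imports like any other ONNX model. A minimal usage sketch (the file name, input names, and shapes below are hypothetical):

```python
import onnx
from tvm import relay

onnx_model = onnx.load("bert_with_contrib_ops.onnx")
shape_dict = {
    "input_ids": (1, 128),
    "segment_ids": (1, 128),
    "input_mask": (1, 128),
}
mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)
```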