
[ONNX] Add imports for BERT contrib operators #10949

Merged: 17 commits, Apr 13, 2022
196 changes: 193 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -807,9 +807,10 @@ def _impl_v1(cls, inputs, attr, params):
x = inputs[0]

# Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

# Compute gelu
term1 = _op.multiply(half, x)
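This dtype change is what makes the op usable for reduced-precision models: Relay's type checker rejects a binary op whose operands have different dtypes, so the old hard-coded float32 constants failed on float16 BERT graphs. A minimal sketch of the pattern outside the converter (the helper name `const_like` is hypothetical; `infer_type` is TVM's Relay type-inference utility):

import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_type

def const_like(x, value):
    # hypothetical helper: scalar constant whose dtype matches expression x
    dtype = infer_type(x).checked_type.dtype
    return relay.const(value, dtype=dtype)

x = relay.var("x", shape=(2, 8), dtype="float16")
y = relay.multiply(const_like(x, 0.5), x)  # type-checks for float16 inputs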
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)


class EmbedLayerNormalization(OnnxOpConverter):
    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer embeds the input tokens, sums the embeddings, and then applies layer
    normalization.
    """
@classmethod
def _impl_v1(cls, inputs, attr, params):
input_ids = inputs[0]
segment_ids = inputs[1]
word_emb = inputs[2]
pos_emb = inputs[3]
segment_emb = inputs[4]
gamma = inputs[5]
beta = inputs[6]

mask = inputs[7]
pos_ids = inputs[8]

eps = attr["epsilon"] if "epsilon" in attr else 1e-12

(batch_size, seq_len) = infer_shape(input_ids)

if segment_ids:
assert segment_emb

        if pos_ids is None:
            # default position ids cover [0, seq_len) for every batch element
            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")

        word_vec = _op.take(word_emb, input_ids, axis=0)
        segment_vec = None
        if segment_ids:
            segment_vec = _op.take(segment_emb, segment_ids, axis=0)
        pos_vec = _op.take(pos_emb, pos_ids, axis=0)

vec_sum = _op.add(word_vec, pos_vec)
if segment_ids:
vec_sum = _op.add(vec_sum, segment_vec)

eps_dtype = infer_type(word_emb).checked_type.dtype

u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
ln = _op.divide(
_op.subtract(vec_sum, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
ln = _op.multiply(ln, gamma) + beta

mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
if mask:
# calculate number of words per sentence
mask_index = _op.sum(mask, axis=1)

return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
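For reference, a NumPy sketch of the computation this converter emits (a simplified reading that hard-codes the default position ids and omits the attention-mask input; the names are illustrative, not part of the PR):

import numpy as np

def embed_layer_norm_ref(input_ids, segment_ids, word_emb, pos_emb,
                         segment_emb, gamma, beta, eps=1e-12):
    # embed the tokens, sum the three embeddings, then layer-normalize
    batch_size, seq_len = input_ids.shape
    pos_ids = np.tile(np.arange(seq_len), (batch_size, 1))

    vec_sum = word_emb[input_ids] + pos_emb[pos_ids]
    if segment_ids is not None:
        vec_sum += segment_emb[segment_ids]

    u = vec_sum.mean(axis=-1, keepdims=True)
    s = vec_sum.var(axis=-1, keepdims=True)
    ln = (vec_sum - u) / np.sqrt(s + eps) * gamma + beta

    mask_index = np.zeros((batch_size,), dtype="int64")  # no mask input
    return ln, mask_index, vec_sum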


class SkipLayerNormalization(OnnxOpConverter):
    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer sums its two input tensors (plus an optional bias) and applies layer
    normalization over the result.
    """
@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
skip = inputs[1]
gamma = inputs[2]
beta = inputs[3]
bias = inputs[4]

eps = attr["epsilon"] if "epsilon" in attr else 1e-12

x = _op.add(data, skip)
if bias is not None:
x = _op.add(x, bias)

eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
Contributor: nit: this is basically the same normalization calculation as 877-886 above, right? If it's easy, can we pull it out into a common helper function?

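A sketch of what that shared helper might look like, factored straight out of the code above (the name `layer_norm` is hypothetical; `_op`, `infer_type`, and `_op.mean_variance` are the same utilities the converters already use):

def layer_norm(x, eps, gamma, beta):
    # normalize x over its last axis, then apply the affine scale/shift
    eps_dtype = infer_type(x).checked_type.dtype
    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
    output = _op.divide(
        _op.subtract(x, u),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    output = _op.multiply(output, gamma)
    if beta is not None:
        output = _op.add(output, beta)
    return output

Both EmbedLayerNormalization and SkipLayerNormalization could then call it with their respective summed inputs.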
        if beta is not None:
output = _op.add(output, beta)

placeholder = _op.const(0, dtype="float32")
Contributor: what is this placeholder for? The optional returns are mean and inverse standard variance, right?

Author: that's true according to the documentation; however, neither the CPU nor the CUDA onnxruntime implementation of the kernel ever calculates or returns values for these outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc


return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)


class Attention(OnnxOpConverter):
    """Operator converter for Attention from Microsoft onnxruntime contrib opset.

    This is the self-attention mechanism used in transformer models.
    """
@classmethod
def _impl_v1(cls, inputs, attr, params):
num_heads = attr["num_heads"]
assert (
"qkv_hidden_sizes" not in attr
), "different hidden sizes for Q, K, V are not currently supported"
assert "unidirectional" not in attr, "unidirectional attention not current supported"

# (batch, seq, in_hidden)
input_emb = inputs[0]

# (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
weight = inputs[1]

# (3 * out_hidden,)
bias = inputs[2]

# 1. ( batch, 1, max_seq, max_seq)
# 2. ( batch, past_seq + seq,)
# 3. ( batch, seq, past_seq + seq,)
# 4. ( batch,)
# 5. (2 * batch,)
# For now, we only support case 2.
mask_index = inputs[3]

# (2, batch, num_heads, past_seq, head_size)
past = inputs[4]

# (batch, num_heads, seq, seq)
extra_add = inputs[5]

(batch_size, seq_len, in_hidden) = infer_shape(input_emb)
(out_hidden_x3,) = infer_shape(bias)
assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
out_hidden = out_hidden_x3 // 3
assert (
out_hidden % num_heads == 0
), "output hidden size should be divisible by number of attention heads"
head_size = out_hidden // num_heads

mask_index_shape = infer_shape(mask_index)
assert (
len(mask_index_shape) == 2
and mask_index_shape[0] == batch_size
and mask_index_shape[1] == seq_len
), "currently only support (batch_size, sequence_length) mask index"

assert past is None, "past K, V state is not currently supported"
assert extra_add is None, "extra add to QxK not currently supported"

# split weight and biases and do the matmuls
w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
# need to merge batch dimensions since TVM matmul is 2D
input_emb = _op.reverse_reshape(input_emb, (-1, 0))
Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)

# massage tensors in preparation for batched matmul
def massage(tensor):
tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))

# (batch_size, num_heads, seq_len, head_size)
tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])

# (batch_size * num_heads, seq_len, head_size)
return _op.reverse_reshape(tensor, (-1, 0, 0))

Q = massage(Q)
K = massage(K)
V = massage(V)

K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
present = _op.stack([K_present, V_present], axis=0)

att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
score_dtype = infer_type(att_scores).checked_type.dtype
        att_scores = _op.divide(
            att_scores,
            _op.const(np.sqrt(head_size), dtype=score_dtype),
        )
att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))

# build the attention mask
att_mask = _op.cast(mask_index, score_dtype)
att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))

# apply the mask
att_scores = _op.add(att_scores, att_mask)
att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))

att_probs = _op.nn.softmax(att_scores, axis=-1)

output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
output = _op.transpose(output, axes=[0, 2, 1, 3])
output = _op.reshape(output, (0, 0, out_hidden))

return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
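To make the mask arithmetic above easier to follow, here is a NumPy sketch of the masked scaled dot-product attention the converter emits (illustrative only; it assumes the case-2 mask layout the converter supports, where 1 marks real tokens and 0 marks padding):

import numpy as np

def attention_ref(Q, K, V, mask_index):
    # Q, K, V: (batch * num_heads, seq_len, head_size)
    # mask_index: (batch, seq_len)
    batch, seq_len = mask_index.shape
    head_size = Q.shape[-1]

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_size)

    # additive mask: 0 where attended, -10000 where padded, broadcast over
    # heads and query positions exactly as in the converter
    att_mask = (1.0 - mask_index)[:, None, None, :] * -10000.0
    scores = scores.reshape(batch, -1, seq_len, seq_len) + att_mask
    scores = scores.reshape(-1, seq_len, seq_len)

    # numerically stable softmax over the key axis
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ V  # (batch * num_heads, seq_len, head_size)

Padded positions receive a score of roughly -10000 before the softmax, so their attention weights come out as approximately zero.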


class Gemm(OnnxOpConverter):
"""Operator converter for Gemm."""

@@ -4737,6 +4924,9 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
"EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
"SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
"Attention": Attention.get_converter(opset),
"Exp": Renamer("exp"),
"Greater": Renamer("greater"),
"GreaterOrEqual": Renamer("greater_equal"),
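With the three ops registered in the convert map, importing a model that uses them needs no extra flags. A hedged usage sketch (the file name and input names are illustrative, not from this PR; they would match whatever graph onnxruntime's transformer optimizer produced):

import onnx
import tvm
from tvm import relay

# illustrative: an ONNX graph containing the com.microsoft contrib ops
model = onnx.load("bert_optimized.onnx")
shape_dict = {
    "input_ids": (1, 128),
    "segment_ids": (1, 128),
    "input_mask": (1, 128),
}

mod, params = relay.frontend.from_onnx(model, shape_dict)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)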