[ONNX] Add imports for BERT contrib operators #10949

Merged
merged 17 commits on Apr 13, 2022
fix Attention
altanh committed Apr 8, 2022
commit 4190bb2ebc5b84e5397386000788955e7fbda537
python/tvm/relay/frontend/onnx.py: 72 changes (46 additions & 26 deletions)
@@ -850,6 +850,8 @@ def _impl_v1(cls, inputs, attr, params):
         mask = inputs[7]
         pos_ids = inputs[8]

+        eps = attr["epsilon"] if "epsilon" in attr else 0.0
+
         (batch_size, seq_len) = infer_shape(input_ids)

         if segment_ids:
@@ -869,7 +871,10 @@ def _impl_v1(cls, inputs, attr, params):
         eps_dtype = infer_type(word_emb).checked_type.dtype

         u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
-        ln = _op.divide(_op.subtract(vec_sum, u), _op.sqrt(_op.add(s, _op.const(attr["epsilon"], dtype=eps_dtype))))
+        ln = _op.divide(
+            _op.subtract(vec_sum, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
         ln = _op.multiply(ln, gamma) + beta

         # TODO: actually calculate this
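Note: the hunk above is standard layer normalization over the hidden axis; the change routes the epsilon through the new local `eps` default instead of indexing `attr` directly. A minimal numpy sketch of the same computation (the function name and the `(batch, seq, hidden)` shape are assumptions for illustration):

```python
import numpy as np

def embed_layer_norm(vec_sum, gamma, beta, eps=0.0):
    # mean and variance over the hidden axis, as in mean_variance(axis=-1)
    u = vec_sum.mean(axis=-1, keepdims=True)
    s = vec_sum.var(axis=-1, keepdims=True)
    # normalize, then apply the learned scale and shift
    ln = (vec_sum - u) / np.sqrt(s + eps)
    return ln * gamma + beta
```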
@@ -888,8 +893,10 @@ class Attention(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         num_heads = attr["num_heads"]
-        assert "qkv_hidden_sizes" not in attr
-        assert "unidirectional" not in attr
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not currently supported"

         # (batch, seq, in_hidden)
         input_emb = inputs[0]
@@ -905,6 +912,7 @@ def _impl_v1(cls, inputs, attr, params):
         # 3. ( batch, seq, past_seq + seq,)
         # 4. ( batch,)
         # 5. (2 * batch,)
+        # For now, we only support case 2.
         mask_index = inputs[3]

         # (2, batch, num_heads, past_seq, head_size)
@@ -915,44 +923,56 @@

         (batch_size, seq_len, in_hidden) = infer_shape(input_emb)
         (out_hidden_x3,) = infer_shape(bias)
-        assert out_hidden_x3 % 3 == 0
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
         out_hidden = out_hidden_x3 // 3
-        assert out_hidden % num_heads == 0
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
         head_size = out_hidden // num_heads

         mask_index_shape = infer_shape(mask_index)
-        assert len(mask_index_shape) == 2
-        assert mask_index_shape[0] == batch_size
-        assert mask_index_shape[1] == seq_len
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"

-        assert past is None
-        assert extra_add is None
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"

-        # decompose weight into Q, K, V: (in_hidden, out_hidden) and do the matmuls
+        # split weight and biases and do the matmuls
         w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
-
-        Q = _op.nn.matmul(input_emb, w_Q)
-        K = _op.nn.matmul(input_emb, w_K)
-        V = _op.nn.matmul(input_emb, w_V)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
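As the new comment notes, the relay matmul used here is 2-D, so the batch and sequence dims are merged first; in `reverse_reshape(input_emb, (-1, 0))` the `0` keeps the last axis and `-1` absorbs the rest. A numpy sketch of the equivalent projection (the toy shapes are assumptions):

```python
import numpy as np

batch_size, seq_len, in_hidden, out_hidden = 2, 4, 8, 8
input_emb = np.random.rand(batch_size, seq_len, in_hidden).astype("float32")
weight = np.random.rand(in_hidden, 3 * out_hidden).astype("float32")
bias = np.random.rand(3 * out_hidden).astype("float32")

w_Q, w_K, w_V = np.split(weight, 3, axis=1)  # each (in_hidden, out_hidden)
b_Q, b_K, b_V = np.split(bias, 3, axis=0)    # each (out_hidden,)

flat = input_emb.reshape(-1, in_hidden)      # (batch * seq, in_hidden)
Q = flat @ w_Q + b_Q
K = flat @ w_K + b_K
V = flat @ w_V + b_V                         # each (batch * seq, out_hidden)
```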

         # massage tensors in preparation for batched matmul
-        def massage(tensor, is_V=False):
-            axes = [0, 2, 3, 1] if is_V else [0, 2, 1, 3]
+        def massage(tensor):
             tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-            tensor = _op.transpose(tensor, axes=axes)
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
             return _op.reverse_reshape(tensor, (-1, 0, 0))

         Q = massage(Q)
         K = massage(K)
-        V = massage(V, is_V=True)
+        V = massage(V)
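`massage` moves the heads axis ahead of the sequence axis and folds it into the batch dimension so one batched matmul covers all heads; with the simplified signature, V now gets the same layout as Q and K, matching the explicit transpose flags used below. Continuing the numpy sketch above (the `num_heads` value is an assumption):

```python
num_heads = 2
head_size = out_hidden // num_heads

def massage_np(t):
    t = t.reshape(batch_size, seq_len, num_heads, head_size)
    t = t.transpose(0, 2, 1, 3)  # (batch, heads, seq, head_size)
    return t.reshape(batch_size * num_heads, seq_len, head_size)

Q3, K3, V3 = massage_np(Q), massage_np(K), massage_np(V)
```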

         K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
         V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
         present = _op.stack([K_present, V_present], axis=0)

-        att_scores = _op.nn.batch_matmul(Q, K)
-        att_scores = _op.divide(att_scores, _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype))
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
+        )
         att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
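With K kept in `(batch * heads, seq, head_size)` layout, `transpose_b=True` makes this the usual `Q @ K^T / sqrt(head_size)` scaled dot product (the old call relied on relay's batch_matmul transposing B by default). Continuing the numpy sketch:

```python
att_scores = Q3 @ K3.transpose(0, 2, 1)      # (batch * heads, seq, seq)
att_scores = att_scores / np.sqrt(head_size)
att_scores = att_scores.reshape(batch_size, num_heads, seq_len, seq_len)
```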

         # build the attention mask
@@ -967,12 +987,12 @@ def massage(tensor, is_V=False):

         att_probs = _op.nn.softmax(att_scores, axis=-1)

-        C = _op.nn.batch_matmul(att_probs, V)
-        C = _op.reverse_reshape(C, (-1, num_heads, 0, 0))
-        C = _op.transpose(C, axes=[0, 2, 1, 3])
-        C = _op.reshape(C, (0, 0, out_hidden))
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))

-        return _expr.TupleWrapper(_expr.Tuple([C, present]), 2)
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
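The renamed `output` path is the inverse of `massage`: the batched matmul yields `(batch * heads, seq, head_size)`, which is unfolded, transposed, and flattened back to `(batch, seq, out_hidden)`. A final piece of the numpy sketch (the mask addition from the collapsed region is omitted, and the softmax helper is an assumption):

```python
def softmax_np(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

probs = softmax_np(att_scores).reshape(batch_size * num_heads, seq_len, seq_len)
output = probs @ V3                          # (batch * heads, seq, head_size)
output = output.reshape(batch_size, num_heads, seq_len, head_size)
output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, out_hidden)
```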


class Gemm(OnnxOpConverter):