[ONNX] Add imports for BERT contrib operators #10949

Merged
merged 17 commits on Apr 13, 2022
factor out layer norm computation
altanh committed Apr 10, 2022
commit dbb7df3bbc92a29fdfbb5a89fa576d49bae8f82c
36 changes: 18 additions & 18 deletions python/tvm/relay/frontend/onnx.py
@@ -874,14 +874,7 @@ def _impl_v1(cls, inputs, attr, params):
if segment_ids:
vec_sum = _op.add(vec_sum, segment_vec)

eps_dtype = infer_type(word_emb).checked_type.dtype

u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
ln = _op.divide(
_op.subtract(vec_sum, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
ln = _op.multiply(ln, gamma) + beta
ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta)
@margaretqian (Contributor) commented on Apr 11, 2022:
nit: maybe instead of referencing SkipLayerNormalization here, you could create a LayerNormalization base class that contains _compute_layer_norm? sort of like how Pool is the base class for MaxPool/AveragePool etc?

Contributor replied:
redefining _compute_layer_norm as a global func in this file -- won't put it in a separate class as semantically LayerNorm is not an onnx operator.
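
For illustration only, a rough sketch of the two layouts discussed in this thread. Neither block is code from this PR: the LayerNormalization base class is the reviewer's hypothetical, the module-level helper is the author's stated plan, and the bodies are elided.

# Assumes TVM is installed; OnnxOpConverter is defined in python/tvm/relay/frontend/onnx.py.
from tvm.relay.frontend.onnx import OnnxOpConverter

# Option A (reviewer's suggestion): a shared base class, mirroring how Pool
# backs MaxPool/AveragePool in this frontend.
class LayerNormalization(OnnxOpConverter):
    @staticmethod
    def _compute_layer_norm(x, eps, gamma, beta):
        ...  # same math as the helper added in the hunk below

class SkipLayerNormalization(LayerNormalization):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        ...  # would call cls._compute_layer_norm(x, eps, gamma, beta)

# Option B (author's stated plan): a module-level helper inside onnx.py,
# since LayerNorm itself is not an ONNX operator.
def _compute_layer_norm(x, eps, gamma, beta):
    ...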


mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
if mask:
@@ -898,6 +891,21 @@ class SkipLayerNormalization(OnnxOpConverter):
normalization.
"""

@staticmethod
def _compute_layer_norm(x, eps, gamma, beta):
eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
if beta is not None:
output = _op.add(output, beta)

return output

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
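
For reference, the helper added in the hunk above is a direct expression of the usual layer-norm formula: output = gamma * (x - mean(x)) / sqrt(var(x) + eps) + beta, computed over the last axis. Below is a minimal NumPy sketch of that math; it is not part of the PR, the shapes are illustrative, and it assumes _op.mean_variance yields the population (ddof=0) variance.

import numpy as np

def layer_norm_ref(x, eps, gamma, beta=None):
    # Mirrors _compute_layer_norm: normalize over the last axis, then apply
    # the elementwise scale (gamma) and optional shift (beta).
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)  # population variance (ddof=0)
    out = (x - u) / np.sqrt(s + eps)
    out = out * gamma
    if beta is not None:
        out = out + beta
    return out

x = np.random.rand(2, 4, 8).astype("float32")
gamma = np.ones(8, dtype="float32")
beta = np.zeros(8, dtype="float32")
print(layer_norm_ref(x, 1e-12, gamma, beta).shape)  # (2, 4, 8)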
@@ -912,17 +920,9 @@ def _impl_v1(cls, inputs, attr, params):
if bias is not None:
x = _op.add(x, bias)

eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
_op.subtract(x, u),
_op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
if beta:
output = _op.add(output, beta)
output = SkipLayerNormalization._compute_layer_norm(x, eps, gamma, beta)

# onnxruntime doesn't compute the other outputs, despite the documentation
placeholder = _op.const(0, dtype="float32")
Contributor commented:
What is this placeholder for? The optional returns are the mean and inverse standard variance, right?

Contributor Author replied:
That's true according to the documentation; however, both the CUDA and C++ onnxruntime implementations of the kernels never actually calculate or return values for these outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc


return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)