[ONNX] Add imports for BERT contrib operators #10949
Changes from 1 commit
@@ -874,14 +874,7 @@ def _impl_v1(cls, inputs, attr, params):
         if segment_ids:
             vec_sum = _op.add(vec_sum, segment_vec)

-        eps_dtype = infer_type(word_emb).checked_type.dtype
-
-        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
-        ln = _op.divide(
-            _op.subtract(vec_sum, u),
-            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
-        )
-        ln = _op.multiply(ln, gamma) + beta
+        ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta)

         mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
         if mask:
@@ -898,6 +891,21 @@ class SkipLayerNormalization(OnnxOpConverter):
     normalization.
     """

+    @staticmethod
+    def _compute_layer_norm(x, eps, gamma, beta):
+        eps_dtype = infer_type(x).checked_type.dtype
+
+        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+        output = _op.divide(
+            _op.subtract(x, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        output = _op.multiply(output, gamma)
+        if beta is not None:
+            output = _op.add(output, beta)
+
+        return output
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         data = inputs[0]
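For reference, the Relay expression built by the new _compute_layer_norm helper is standard layer normalization over the last axis. A minimal NumPy sketch of the same math (the function and variable names here are illustrative, not part of the PR):

    import numpy as np

    def layer_norm_reference(x, gamma, beta, eps=1e-12):
        # Normalize over the last axis, then apply the learned scale and shift,
        # mirroring _op.mean_variance / _op.divide / _op.multiply in the converter.
        u = x.mean(axis=-1, keepdims=True)
        s = x.var(axis=-1, keepdims=True)
        out = (x - u) / np.sqrt(s + eps)
        out = out * gamma
        if beta is not None:
            out = out + beta
        return out

    # e.g. for a (batch, seq_len, hidden) activation with per-hidden-dim gamma/beta:
    x = np.random.randn(2, 4, 8).astype("float32")
    gamma = np.ones(8, dtype="float32")
    beta = np.zeros(8, dtype="float32")
    print(layer_norm_reference(x, gamma, beta).shape)  # (2, 4, 8)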
@@ -912,17 +920,9 @@ def _impl_v1(cls, inputs, attr, params):
         if bias is not None:
             x = _op.add(x, bias)

-        eps_dtype = infer_type(x).checked_type.dtype
-
-        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
-        output = _op.divide(
-            _op.subtract(x, u),
-            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
-        )
-        output = _op.multiply(output, gamma)
-        if beta:
-            output = _op.add(output, beta)
+        output = SkipLayerNormalization._compute_layer_norm(x, eps, gamma, beta)

         # onnxruntime doesn't compute the other outputs, despite the documentation
         placeholder = _op.const(0, dtype="float32")

         return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)

Review thread on the placeholder line:

Comment: What is this placeholder for? The optional returns are the mean and the inverse standard variance, right?

Reply: That's true according to the documentation; however, both the CUDA and the C++ onnxruntime implementations of the kernel never actually calculate or return values for these outputs: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
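For context on that reply: the two optional outputs are exactly the statistics the normalization already computes. A hedged, purely illustrative fragment of what populating them could look like inside SkipLayerNormalization._impl_v1, assuming the converter's local names (x, eps, gamma, beta, output) and reusing only ops already used in this file; the PR instead returns constants because onnxruntime never produces these outputs:

    # Hypothetical alternative (not part of the PR): derive the mean and the
    # inverse standard variance from the same statistics used for normalization,
    # instead of returning placeholder constants.
    eps_dtype = infer_type(x).checked_type.dtype
    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
    inv_std_var = _op.divide(
        _op.const(1.0, dtype=eps_dtype),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    return _expr.TupleWrapper(_expr.Tuple([output, u, inv_std_var]), 3)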
Review thread on the helper:

Comment: nit: maybe instead of referencing SkipLayerNormalization here, you could create a LayerNormalization base class that contains _compute_layer_norm? Sort of like how Pool is the base class for MaxPool / AveragePool, etc.?

Reply: Redefining _compute_layer_norm as a global function in this file -- won't put it in a separate class, as semantically LayerNorm is not an ONNX operator.
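A rough sketch of that follow-up (the free-function name and final form are up to the author; this simply restates the staticmethod from the diff as a module-level function in the same file, using the existing _op / infer_type imports):

    def compute_layer_norm(x, eps, gamma, beta):
        """Layer normalization over the last axis, shared by the
        EmbedLayerNormalization and SkipLayerNormalization converters."""
        eps_dtype = infer_type(x).checked_type.dtype

        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
        output = _op.divide(
            _op.subtract(x, u),
            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
        )
        output = _op.multiply(output, gamma)
        if beta is not None:
            output = _op.add(output, beta)

        return output

    # Call sites would then read, e.g.:
    #     ln = compute_layer_norm(vec_sum, eps, gamma, beta)
    # instead of SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta).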