
[ONNX] Add imports for BERT contrib operators #10949

Merged: 17 commits into apache:main on Apr 13, 2022

Conversation

@altanh (Contributor) commented Apr 8, 2022

  • Add imports for Attention, EmbedLayerNormalization, SkipLayerNormalization
  • Fix a small dtype bug in the Gelu import (see the sketch below)

cc @AndrewZhaoLuo @margaretqian @sfvaroglu
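The Gelu dtype fix mentioned above suggests creating constants in the input's dtype rather than a hard-coded one. As a hedged illustration only (not this PR's actual diff), a dtype-aware Gelu lowering might look like the sketch below; it assumes the OnnxOpConverter/_impl_v1 converter pattern and the _op/infer_type aliases conventionally used in python/tvm/relay/frontend/onnx.py.

import math

# Sketch only. The names below mirror the aliases assumed to be available in
# python/tvm/relay/frontend/onnx.py (not copied from this PR):
#   from tvm.relay import op as _op
#   from tvm.relay.frontend.common import infer_type
#   OnnxOpConverter is the frontend's converter base class.

class Gelu(OnnxOpConverter):
    """Hypothetical converter sketch for the Gelu contrib op."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        x = inputs[0]
        # Create constants in the input's dtype instead of hard-coding
        # float32, so e.g. float16 models import with consistent dtypes.
        x_dtype = infer_type(x).checked_type.dtype
        half = _op.const(0.5, dtype=x_dtype)
        one = _op.const(1.0, dtype=x_dtype)
        sqrt2 = _op.const(math.sqrt(2.0), dtype=x_dtype)
        # Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        return _op.multiply(
            _op.multiply(half, x),
            _op.add(one, _op.erf(_op.divide(x, sqrt2))),
        )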

@AndrewZhaoLuo (Contributor) left a comment

Do you have a reference implementation of these operators? I did not look too closely at the implementation, but I left a few comments.

Five review threads on python/tvm/relay/frontend/onnx.py (all resolved; one outdated).
@AndrewZhaoLuo (Contributor) left a comment

LGTM generally, but would appreciate another pair of eyes @margaretqian @sfvaroglu

Review thread on tests/python/frontend/onnx/test_forward.py (resolved).
if beta:
    output = _op.add(output, beta)

placeholder = _op.const(0, dtype="float32")
A contributor commented:

What is this placeholder for? The optional returns are the mean and the inverse standard variance, right?

@altanh (Contributor, Author) replied:

That's true according to the documentation; however, both the CPU (C++) and CUDA onnxruntime implementations of these kernels never actually calculate or return values for those outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
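To make the output-arity question concrete, here is a hedged sketch (not the PR's actual code) of how the importer could pad the result with dummy values for the never-computed optional outputs. The helper name is hypothetical, and _op/_expr are assumed to be the usual tvm.relay.op / tvm.relay.expr aliases from the frontend.

def _skip_ln_outputs(output):
    """Hypothetical helper: pad the real result with placeholder constants
    for the optional mean / inverse-std-variance outputs, mirroring the
    onnxruntime kernels that never compute them."""
    placeholder = _op.const(0, dtype="float32")
    return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)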

Comment on lines 915 to 922
eps_dtype = infer_type(x).checked_type.dtype

u, s = _op.mean_variance(x, axis=-1, keepdims=True)
output = _op.divide(
    _op.subtract(x, u),
    _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
)
output = _op.multiply(output, gamma)
A contributor commented:

nit: this is basically the same normalization calculation as lines 877-886 above, right? If it's easy, can we pull it out into a common helper function?

if segment_ids:
    vec_sum = _op.add(vec_sum, segment_vec)

ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta)
@margaretqian (Contributor) commented Apr 11, 2022

nit: maybe instead of referencing SkipLayerNormalization here, you could create a LayerNormalization base class that contains _compute_layer_norm, sort of like how Pool is the base class for MaxPool/AveragePool, etc.?

A contributor replied:

Redefining _compute_layer_norm as a global function in this file -- I won't put it in a separate class, since semantically LayerNorm is not an ONNX operator.
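For reference, a sketch of what that module-level helper could look like, reconstructed directly from the normalization code quoted above; the final name and signature in the merged PR may differ.

def _compute_layer_norm(x, eps, gamma, beta):
    """Layer-normalize x over its last axis, then scale by gamma and
    (optionally) shift by beta."""
    eps_dtype = infer_type(x).checked_type.dtype
    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
    output = _op.divide(
        _op.subtract(x, u),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    output = _op.multiply(output, gamma)
    # beta is optional for some callers (e.g. when the ONNX node omits it).
    if beta is not None:
        output = _op.add(output, beta)
    return output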

@margaretqian (Contributor) left a comment

lgtm!

@AndrewZhaoLuo merged commit 11b8cd3 into apache:main on Apr 13, 2022.
AndrewZhaoLuo added a commit to AndrewZhaoLuo/tvm that referenced this pull request Apr 15, 2022
* main: (527 commits)
  [hexagon] 'add_hvx' test to explore HVX usage. (apache#10604)
  [COMMUNITY] @yzh119 -> Reviewer (apache#10993)
  [Metaschedule] Make custom schedule_rule registration optional (apache#10975)
  [ONNX] Add imports for BERT contrib operators (apache#10949)
  sort axes (apache#10985)
  [Hexagon] Remove HexagonBuffer external constructor and support (apache#10978)
  [CI] Update GPU image (apache#10992)
  [Runtime][Vulkan] Add RGP support to TVM for vulkan device (apache#10953)
  [FIX] resolve int64/32 for AttrStmtNode (apache#10983)
  [TVMC] Allow output module name to be passed as a command line argument (apache#10962)
  [ONNX] Add MatMulInteger importer (apache#10450)
  [COMMUNITY] @guberti -> Reviewer (apache#10976)
  Support `qnn.conv2d` in FoldExplicitPading (apache#10982)
  change Hexagon docker version (apache#10981)
  remove exception handling of autotvm xgboost extract functions (apache#10948)
  [CUDNN] Add partitioning support for conv2d and log_softmax (apache#10961)
  [Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (apache#10905)
  [Hexagon] Move aot/graph_executor interactions into launcher (apache#10907)
  [HEXAGON] Split huge 1D DMA Transfers into smaller transfers with legal sizes. (apache#10971)
  [CI][DOCKER] Add pytest-lazy-fixture to images (apache#10970)
  ...
Lucien0 pushed a commit to Lucien0/tvm that referenced this pull request Apr 19, 2022
* EmbedLayerNormalization, Attention

* fix Attention

* SkipLayerNormalization

* fix dtype bug in Gelu

Co-authored-by: An Wang <anwang2009@gmail.com>

* missing parameterize_targets

* lint

* lint

* comments

* fix small thing

* factor out layer norm computation

* layernorm func

* add optional args to test

* upgrade onnxrt version

* no upgrade onnx

* fix tests

* int32

* fix tests

Co-authored-by: An Wang <anwang2009@gmail.com>
altanh added a commit to altanh/tvm that referenced this pull request Apr 28, 2022