Add Attention Microsoft Contrib Operator #3816
base: develop
Conversation
should be the main driver of the attention head here
Force-pushed from 6ec57c7 to 01012e8
Breaking this up into smaller pieces for the optional args as I populate the proper vector inputs, before tying things to the calculation and creation of the multi-head attention layers.
Need to finish the other input args and check inferred and parsed attributes.
Force-pushed from ef0e053 to 4ef0ba7
Split this up to clean up the handle_inputs call and separate errors/state when we acquire attributes.
…nput correctly. Need to fill in the parser piece, but this checks and ensures we're working with the proper batch size for our calculations within the attention head. Debug still needs to be removed, but this ensures we're seeing the proper number of heads, batched correctly.
…r now. Add some sort of tracked state for the padding modes of mask_index, for input linear layer masking prior to the attention head splits.
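For context on those padding modes, here is a minimal Python sketch (the helper name is mine, not this PR's parser code) of the mask_index layouts the com.microsoft.Attention spec allows, which is what a tracked padding-mode state would have to distinguish:

```python
# Hypothetical sketch: classify the mask_index layouts allowed by the spec.
# The 1-D forms encode padding positions per batch; higher ranks are raw masks.
def classify_mask_index(mask_dims, batch_size):
    if len(mask_dims) == 1:
        if mask_dims[0] == batch_size:
            return "right_side_padding"      # one valid-length index per batch
        if mask_dims[0] == 2 * batch_size:
            return "left_and_right_padding"  # end positions, then start positions
        raise ValueError("1-D mask_index must be (batch) or (2*batch)")
    if len(mask_dims) in (2, 3, 4):
        return "raw_mask"                    # e.g. (batch, total_sequence_length)
    raise ValueError("unsupported mask_index rank")
```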
Too much C++, too little Python.
Tests more representative of customer workloads and the models we see in the wild. Need to finish these to complete the parser tests. Will add tests for other inputs and error cases later.
Give an explanation of how things are parsed, since the sizes of the mask, input, and weight tensors, as well as the attributes, can change how certain inferred values in the parser are calculated. This is due to how the spec specifies inputs are to be handled on parse.
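As a hedged illustration of that shape-driven inference (the function and names are mine, not the parser's), per the spec the input is (batch, sequence, input_hidden_size) and the weights are (input_hidden_size, q_hidden + k_hidden + v_hidden):

```python
# Sketch of inferring sizes from input/weight shapes and attributes per the spec.
def infer_attention_sizes(input_dims, weight_dims, num_heads, qkv_hidden_sizes=None):
    batch, seq_len, input_hidden = input_dims
    if qkv_hidden_sizes is None:
        # Without the attribute, the Q, K, and V projections are equal-sized.
        assert weight_dims[1] % 3 == 0
        q = k = v = weight_dims[1] // 3
    else:
        q, k, v = qkv_hidden_sizes
    assert weight_dims[0] == input_hidden
    assert q == k, "Q and K hidden sizes must match for Q @ K^T"
    assert q % num_heads == 0 and v % num_heads == 0
    return {"batch": batch, "seq_len": seq_len,
            "head_size": q // num_heads, "v_head_size": v // num_heads}
```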
Clean up debug from input_linear_to_qkv and pass the input in via a vector of instructions.
Force-pushed from 32f91e4 to a7f8a3b
Need to split the input based on num_heads.
…cat and slice from Q, K, V mats. This should theoretically reduce the overhead from slices and make things easier to fuse. Previous changes caused our parser to generate a large number of slice/dot operations for each attention head. By adding the head dimension instead of slicing out the pieces and handling those instructions, we reduce parser time and also simplify how things get fused, as well as the code for this. It also removes the need to add splits for the mask and modifications on the input mat before the softmax. Still seeing a large slowdown vs. the previous parser: 142 ms (225 QPS) on the target model vs. the previous parser's 129 ms (250 QPS). It appears we're still not fusing the attention pieces correctly.
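A NumPy stand-in (not MIGraphX IR) for what this change does: one reshape/transpose adds the head dimension, so the per-head work becomes two batched matmuls instead of a slice and a dot per head:

```python
import numpy as np

def multihead_attention(q, k, v, num_heads):
    batch, seq, hidden = q.shape
    head = hidden // num_heads
    # (batch, seq, hidden) -> (batch, num_heads, seq, head_size)
    split = lambda x: x.reshape(batch, seq, num_heads, head).transpose(0, 2, 1, 3)
    qh, kh, vh = split(q), split(k), split(v)
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head)  # one batched GEMM
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)                   # softmax over keys
    out = probs @ vh                                        # second batched GEMM
    # Merge heads back to (batch, seq, hidden); no per-head concat required.
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, hidden)
```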
Adding a change that gets rid of the concat for now, based on @pfultz2's recommendation. Seeing a speedup with this when using MLIR fusion on an attention block. Mirrored a model workload with just attention in a gen_onnx.py test and we're getting a significant speedup.
(Benchmark outputs elided: before with no MLIR flags, after with flags, and vs. prior to removing the concat with flags and with no flags.)
Adding in code from Ahsan to handle the qkv_hidden_size input.
This also allows us to fuse the slice block for each matrix instead, and simplifies the logic for the input linear layer.
Found a way to get the slice fused into the attention head as well, if I handle the weight dot product before the slice block. Simplified the logic too. Getting a slight speedup per attention head by doing this, leaving only the GEMM input as the issue now. I'm able to now let us use the …
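A hedged NumPy sketch (names are mine) of the ordering described here: run the full weight GEMM first, then slice the fused result into Q, K, V along the last axis using qkv_hidden_sizes, so the slices land next to the attention block where they can fuse:

```python
import numpy as np

def input_linear_to_qkv(x, weight, bias, qkv_hidden_sizes):
    fused = x @ weight + bias                 # one large GEMM covering Q, K, V
    q_size, k_size, _ = qkv_hidden_sizes      # sizes may be non-uniform
    q = fused[..., :q_size]                   # slices after the dot, not before
    k = fused[..., q_size:q_size + k_size]
    v = fused[..., q_size + k_size:]
    return q, k, v
```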
Still need to add a case for when buffers aren't shared. Need to stack the KV matrix output.
Need to get changes sorted, as attention isn't fusing anymore when using this PR in testing (#3993). We want this to be supported, as it reduces the number of larger dot/GEMMs and cuts our GEMM times from 98+ ms to ~75 ms, meaning this gives us about a 20-25% boost should attention fuse correctly.
…tests. This is done because this operator has a bunch of attributes that are optional apart from num_heads. We need to specify different flavors of parameters to ensure things are parsing and working as expected.
…arious inputs - Will reuse these for verify - Made a smaller test case for inputs so it's easier to get verify data later.
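A sketch in the spirit of a gen_onnx.py case (the function name and shapes here are assumptions, not this PR's actual tests): a minimal com.microsoft.Attention node with only the required num_heads attribute, sized like the small verify case:

```python
import onnx
from onnx import helper, TensorProto

def attention_double_head_test():
    # Minimal Attention node: only the required num_heads attribute is set.
    node = helper.make_node(
        "Attention",
        inputs=["input", "weights", "bias"],
        outputs=["out"],
        domain="com.microsoft",
        num_heads=2)
    graph = helper.make_graph(
        [node], "attention_double_head",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 4, 4]),
         helper.make_tensor_value_info("weights", TensorProto.FLOAT, [4, 12]),
         helper.make_tensor_value_info("bias", TensorProto.FLOAT, [12])],
        [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 4, 4])])
    return helper.make_model(graph)
```

Variants of this with mask_index, qkv_hidden_sizes, past/present, and malformed shapes would cover the optional-attribute flavors and error cases mentioned above.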
Got fusion sorted for this and benchmarked things using the customer script. Saw a 21% speedup using Paul's change. Added past/present inputs/outputs and unidirectional, leaving rotary input encoding out for now to just get this in; I can open an issue for that. All I have left to do is the parser/verify tests. I've reduced the size of attention_double_head for verify-test simplicity (batch 2, seq 4, hidden 4, etc.). The math should work out to be the same once this is scaled up.
Use this test data for verification (WIP). Simplify tests to cover various cases; may or may not import more OnnxRT test cases for this. Getting failures on this right now for the gold data; seeing inconsistent output.
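One hedged recipe for the gold-data problem (all names and tolerances here are assumptions, not the PR's verify code): fix a seed, compute a NumPy reference for the scaled-down case, and compare within a tolerance rather than hard-coding values that drift between runs:

```python
import numpy as np

np.random.seed(0)
batch, seq, hidden, heads = 2, 4, 4, 2
head = hidden // heads
x = np.random.rand(batch, seq, hidden).astype(np.float32)
w = np.random.rand(hidden, 3 * hidden).astype(np.float32)
b = np.random.rand(3 * hidden).astype(np.float32)
q, k, v = np.split(x @ w + b, 3, axis=-1)            # fused QKV projection
s = lambda t: t.reshape(batch, seq, heads, head).transpose(0, 2, 1, 3)
scores = s(q) @ s(k).transpose(0, 1, 3, 2) / np.sqrt(head)
p = np.exp(scores - scores.max(-1, keepdims=True))   # numerically stable softmax
p /= p.sum(-1, keepdims=True)
gold = (p @ s(v)).transpose(0, 2, 1, 3).reshape(batch, seq, hidden)
# np.testing.assert_allclose(migraphx_out, gold, rtol=1e-3, atol=1e-3)
```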
This build is not recommended to merge 🔴
❌ bert-mrpc-tf: ERROR - check error output
2025-05-28 04:26:17.660151: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748424383.360979 184202 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 62973 MB memory: -> device: 0, name: AMD Instinct MI250X/MI250, pci bus id: 0000:32:00.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748424384.268987 184202 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2025-05-28 04:26:34.521466: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc (message repeated 8x)
error: Failure when generating HSACO (repeated 8x)
2025-05-28 04:26:34.523161: E tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc:228] INTERNAL: Generating device code failed.
2025-05-28 04:26:34.524697: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: JIT compilation failed.
2025-05-28 04:26:34.524719: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
2025-05-28 04:26:34.524731: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]] [[import/loss/output/_21]]
2025-05-28 04:26:34.524748: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled.
Key hash: 11217777527359497193
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1407, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1390, in _run_fn
    return self._call_tf_sessionrun(options, feed_dict, fetch_list,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1483, in _call_tf_sessionrun
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.UnknownError: 2 root error(s) found.
  (0) UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]] [[import/loss/output/_21]]
  (1) UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations. 0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in <module>
    main()
  File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 324, in main
    y_out = sess.run(y, feed_dict=tf_dict)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 977, in run
    result = self._run(None, fetches, feed_dict, options_ptr,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1220, in _run
    results = self._do_run(handle, final_targets, final_fetches,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1400, in _do_run
    return self._do_call(_run_fn, feeds, fetches, targets, options,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1426, in _do_call
    raise type(e)(node_def, op, message)  # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.UnknownError: Graph execution error:
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
2 root error(s) found.
  (0) UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]] [[import/loss/output/_21]]
  (1) UNKNOWN: JIT compilation failed. [[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations. 0 derived errors ignored.
Original stack trace for 'import/bert/embeddings/LayerNorm/moments/SquaredDifference':
🔴 unet: FAILED: MIGraphX is not within tolerance - check verbose output
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
🔴 mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output
Spec here:
https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
Useful resources:
https://towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34/
https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853/