
Add Attention Microsoft Contrib Operator #3816


Open

wants to merge 46 commits into develop from add_attention_contrib_op

Conversation

@TedThemistokleous TedThemistokleous added roadmap Tasks to finish for a release onnxruntime PR changes interaction between MIGraphX and Onnxruntime Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase labels Feb 14, 2025
@TedThemistokleous TedThemistokleous self-assigned this Feb 14, 2025
@TedThemistokleous TedThemistokleous linked an issue Feb 20, 2025 that may be closed by this pull request
3 tasks
@causten causten added the high priority A PR with high priority for review and merging. label Feb 21, 2025
Breaking this up into smaller pieces for the optional args as I populate the proper vector inputs before tying things to the calculation and creation of the multi-head attention layers.
Need to finish the other input args and check the inferred and parsed attributes.
Split this up to clean up the handle_inputs call and separate the errors/state when we acquire attributes.
…nput correctly.

Need to fill in the parser piece, but this checks and ensures we're working with the proper batch size for our calculations within the attention head.

Debug still needs to be removed, but this ensures we're seeing the proper number of heads batched correctly.
…r now.

Add some tracked state for the padding modes of mask_index, used for input linear layer masking prior to the attention head splits.
Too much C++, too little Python.
Tests more representative of customer workloads and the models we see in the wild. Need to finish these to complete the parser tests.

Will add tests for other inputs and error cases later
Give an explanation of how things are parsed, since the sizes of the masks, inputs, and weights, as well as the attributes, can change how certain inferred values in the parser are calculated. This is due to how the spec specifies inputs will be handled on parse (see the shape-inference sketch after these notes).
Clean up debug from input_linear_to_qkv and have the input passed in via a vector of instructions.
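
For reference, a minimal, illustrative Python sketch (not the actual parser code; names and shape rules are assumptions based on the com.microsoft Attention spec) of how inferred values such as the head size depend on the input/weight shapes and the num_heads / qkv_hidden_sizes attributes:

```python
# Illustrative only -- mirrors the shape rules of the com.microsoft Attention op:
#   input:   [batch, sequence_length, input_hidden_size]
#   weights: [input_hidden_size, q_size + k_size + v_size]
def infer_attention_dims(input_shape, weight_shape, num_heads, qkv_hidden_sizes=None):
    batch, seq_len, input_hidden = input_shape
    if qkv_hidden_sizes:
        q_size, k_size, v_size = qkv_hidden_sizes
    else:
        # Without the attribute the three projections are assumed equal,
        # so the combined weight column dimension splits evenly into three.
        q_size = k_size = v_size = weight_shape[1] // 3
    assert weight_shape[0] == input_hidden
    assert q_size % num_heads == 0 and v_size % num_heads == 0
    return {"batch": batch,
            "sequence_length": seq_len,
            "head_size": q_size // num_heads,
            "qkv_sizes": (q_size, k_size, v_size)}
```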
@TedThemistokleous TedThemistokleous force-pushed the add_attention_contrib_op branch from 32f91e4 to a7f8a3b Compare March 1, 2025 19:56
Ted Themistokleous added 2 commits March 2, 2025 06:52
@ahsan-ca ahsan-ca removed their assignment Apr 9, 2025
TedThemistokleous and others added 2 commits April 23, 2025 11:13
…cat and slice from Q,K,V mats

This should theoretically reduce the overhead from slices and make fusion easier. The previous changes caused our parser to generate a large number of slice/dot operations for each attention head.

By adding the head dimension instead of slicing out the pieces and handling those instructions, we reduce parser time and also simplify how things get fused, as well as the code for this.

This also removes the need to add splits for the mask and the modifications to the input matrix before the softmax.

Still seeing a large slowdown: 142 ms (225 QPS) on the target model vs. the previous parser's 129 ms (250 QPS). It appears we're still not fusing the attention pieces correctly.
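
As a rough numpy sketch of the two lowerings (shapes and names are illustrative, not the parser's actual instructions): the old approach slices one [batch, seq, head_size] tensor per head, while the new one adds the head dimension with a single reshape/transpose.

```python
import numpy as np

batch, seq, num_heads, head_size = 1, 384, 12, 64
qkv = np.random.rand(batch, seq, num_heads * head_size).astype(np.float32)

# Old approach (illustrative): one slice per head -> num_heads small dot/softmax chains.
heads_old = [qkv[:, :, h * head_size:(h + 1) * head_size] for h in range(num_heads)]

# New approach (illustrative): a single reshape/transpose adds the head dimension,
# so all heads flow through one batched [batch, heads, seq, head_size] tensor.
heads_new = qkv.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

assert np.allclose(np.stack(heads_old, axis=1), heads_new)
```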
@TedThemistokleous (Collaborator, Author)

Adding a change that gets rid of the concat for now, based on @pfultz2's recommendation. Seeing a speedup with this when using MLIR fusion on an attention block. Mirrored a model workload with just attention in a gen_onnx.py test, and we're getting a significant speedup.

Before with no MLIR flags

Summary:
gpu::code_object::mul_add_reduce_max_sub_exp_reduce_sum_div_kernel: 3.86515ms / 1 = 3.86515ms, 51%
gpu::gemm: 2.92227ms / 3 = 0.97409ms, 39%
gpu::code_object::mlir_reshape_dot: 0.344482ms / 1 = 0.344482ms, 5%
gpu::code_object::mlir_reshape_reshape_transpose_dot: 0.327408ms / 1 = 0.327408ms, 5%
gpu::code_object::contiguous_kernel: 0.115116ms / 1 = 0.115116ms, 2%
gpu::code_object::not_convert_mul_kernel: 0.0205068ms / 1 = 0.0205068ms, 1%
load: 0.00884556ms / 7 = 0.00126365ms, 1%
slice: 0.0055792ms / 3 = 0.00185973ms, 1%
multibroadcast: 0.00461838ms / 3 = 0.00153946ms, 1%
@param: 0.00413152ms / 5 = 0.000826304ms, 1%
reshape_lazy: 0.00160806ms / 1 = 0.00160806ms, 1%
unsqueeze: 0.00158582ms / 1 = 0.00158582ms, 1%
check_context::migraphx::gpu::context: 0.00141454ms / 1 = 0.00141454ms, 1%
broadcast: 0.00133182ms / 1 = 0.00133182ms, 1%
hip::hip_allocate_memory: 0.00126186ms / 1 = 0.00126186ms, 1%

Batch size: 1
Rate: 133.76 inferences/sec
Total time: 7.47609ms (Min: 7.46932ms, Max: 7.48673ms, Mean: 7.47635ms, Median: 7.4761ms)
Percentiles (90%, 95%, 99%): (7.48157ms, 7.48357ms, 7.48636ms)
Total instructions time: 7.62531ms
Overhead time: 0.00697326ms, -0.149226ms
Overhead: 0%, -2%
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx

After with MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot

Summary:
gpu::gemm: 2.96798ms / 3 = 0.989326ms, 83%
gpu::code_object::mlir_reshape_reshape_transpose_dot_mul_add_softmax_reshape_dot: 0.482526ms / 1 = 0.482526ms, 14%
gpu::code_object::contiguous_kernel: 0.114383ms / 1 = 0.114383ms, 4%
gpu::code_object::not_convert_mul_kernel: 0.0203474ms / 1 = 0.0203474ms, 1%
load: 0.00602034ms / 5 = 0.00120407ms, 1%
slice: 0.00533918ms / 3 = 0.00177973ms, 1%
multibroadcast: 0.00512556ms / 3 = 0.00170852ms, 1%
@param: 0.00405146ms / 5 = 0.000810292ms, 1%
unsqueeze: 0.00152746ms / 1 = 0.00152746ms, 1%
broadcast: 0.0014327ms / 1 = 0.0014327ms, 1%
check_context::migraphx::gpu::context: 0.00138322ms / 1 = 0.00138322ms, 1%
reshape_lazy: 0.00134576ms / 1 = 0.00134576ms, 1%
hip::hip_allocate_memory: 0.00116304ms / 1 = 0.00116304ms, 1%

Batch size: 1
Rate: 290.792 inferences/sec
Total time: 3.43888ms (Min: 3.42825ms, Max: 3.58141ms, Mean: 3.46037ms, Median: 3.43728ms)
Percentiles (90%, 95%, 99%): (3.55396ms, 3.56477ms, 3.57875ms)
Total instructions time: 3.61262ms
Overhead time: 0.0063449ms, -0.173741ms
Overhead: 0%, -5%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot \ 
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx

Versus prior to removing the concat:

With flags

Summary:
gpu::gemm: 2.92583ms / 3 = 0.975278ms, 51%
gpu::code_object::mlir_slice_reshape_transpose_slice_squeeze_dot_mul_add_softmax_slice_dot: 2.65099ms / 16 = 0.165687ms, 46%
gpu::code_object::concat_kernel: 0.162542ms / 1 = 0.162542ms, 3%
slice: 0.0289778ms / 19 = 0.00152515ms, 1%
multibroadcast: 0.0286749ms / 19 = 0.0015092ms, 1%
load: 0.0240347ms / 20 = 0.00120174ms, 1%
gpu::code_object::not_convert_mul_kernel: 0.020387ms / 1 = 0.020387ms, 1%
@param: 0.00453062ms / 5 = 0.000906124ms, 1%
unsqueeze: 0.00160364ms / 1 = 0.00160364ms, 1%
check_context::migraphx::gpu::context: 0.0013431ms / 1 = 0.0013431ms, 1%
hip::hip_allocate_memory: 0.00128176ms / 1 = 0.00128176ms, 1%

Batch size: 1
Rate: 189.493 inferences/sec
Total time: 5.27725ms (Min: 5.24866ms, Max: 5.3047ms, Mean: 5.27715ms, Median: 5.27773ms)
Percentiles (90%, 95%, 99%): (5.29325ms, 5.29882ms, 5.30175ms)
Total instructions time: 5.8502ms
Overhead time: 0.0229228ms, -0.572951ms
Overhead: 0%, -11%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot \ 
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx

No flags

Summary:
gpu::gemm: 2.92451ms / 3 = 0.974837ms, 61%
gpu::code_object::mul_add_reduce_max_sub_exp_reduce_sum_div_kernel: 0.704009ms / 16 = 0.0440006ms, 15%
gpu::code_object::mlir_slice_reshape_transpose_slice_squeeze_dot: 0.502644ms / 16 = 0.0314153ms, 11%
gpu::code_object::mlir_slice_dot: 0.418183ms / 16 = 0.0261364ms, 9%
gpu::code_object::concat_kernel: 0.162208ms / 1 = 0.162208ms, 4%
load: 0.0630131ms / 52 = 0.00121179ms, 2%
slice: 0.028744ms / 19 = 0.00151284ms, 1%
multibroadcast: 0.0275485ms / 19 = 0.00144992ms, 1%
gpu::code_object::not_convert_mul_kernel: 0.0197071ms / 1 = 0.0197071ms, 1%
@param: 0.0045218ms / 5 = 0.00090436ms, 1%
unsqueeze: 0.00178614ms / 1 = 0.00178614ms, 1%
check_context::migraphx::gpu::context: 0.00130386ms / 1 = 0.00130386ms, 1%
hip::hip_allocate_memory: 0.0012166ms / 1 = 0.0012166ms, 1%

Batch size: 1
Rate: 257.44 inferences/sec
Total time: 3.88439ms (Min: 3.87337ms, Max: 3.90875ms, Mean: 3.8846ms, Median: 3.88467ms)
Percentiles (90%, 95%, 99%): (3.89034ms, 3.8924ms, 3.89919ms)
Total instructions time: 4.8594ms
Overhead time: 0.0366287ms, -0.975003ms
Overhead: 1%, -25%
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx

Adding in code from Ahsan to handle the qkv_hidden_size.
This allows us to also fuse the slice block for each matrix instead, and simplifies the logic for the input linear layer.
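
A minimal sketch of the offset bookkeeping the qkv_hidden_sizes attribute implies (hypothetical helper, not the actual change): since Q, K, and V may have different hidden sizes, the slice bounds along the fused projection's last axis are the running sums.

```python
def qkv_slice_bounds(qkv_hidden_sizes):
    # Hypothetical helper: start/end offsets for slicing the fused
    # input*weights result into Q, K, and V along the last axis.
    q, k, v = qkv_hidden_sizes
    return [(0, q), (q, q + k), (q + k, q + k + v)]

# e.g. qkv_slice_bounds([768, 768, 768]) -> [(0, 768), (768, 1536), (1536, 2304)]
```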
@TedThemistokleous (Collaborator, Author)

Found a way to get the slice fused into the attention head as well if I handle the weight dot product before the slice block. Simplified the logic too. Getting a slight speedup per attention head by doing this, leaving only the input GEMM as the issue now. This also lets us use the attention,dot flags for MLIR without penalty. I get about a 0.01 ms speedup per attention head, and as a result the input broadcast also becomes fused.
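
Roughly, the reordering looks like this (illustrative numpy, not the parser code): instead of slicing the combined weights and doing three dots, do the single dot first and slice Q/K/V from its output, so the slices sit next to the attention block and can fuse into it.

```python
import numpy as np

batch, seq, hidden = 1, 384, 768
x = np.random.rand(batch, seq, hidden).astype(np.float32)
w = np.random.rand(hidden, 3 * hidden).astype(np.float32)

# Slice-then-dot: three separate GEMMs against weight slices.
q1, k1, v1 = (x @ w[:, i * hidden:(i + 1) * hidden] for i in range(3))

# Dot-then-slice: one GEMM, then slices that can fuse with the attention head.
qkv = x @ w
q2, k2, v2 = (qkv[:, :, i * hidden:(i + 1) * hidden] for i in range(3))

assert np.allclose(q1, q2, rtol=1e-3, atol=1e-3)
```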

Summary:
gpu::code_object::mlir_broadcast_dot: 2.9135ms / 1 = 2.9135ms, 83%
gpu::code_object::mlir_slice_reshape_slice_reshape_transpose_dot_mul_add_softmax_slice_reshape_dot: 0.481105ms / 1 = 0.481105ms, 14%
gpu::code_object::contiguous_kernel: 0.115444ms / 1 = 0.115444ms, 4%
gpu::code_object::not_convert_mul_kernel: 0.0186374ms / 1 = 0.0186374ms, 1%
@param: 0.00381764ms / 5 = 0.000763528ms, 1%
load: 0.00313322ms / 3 = 0.00104441ms, 1%
reshape_lazy: 0.00130248ms / 1 = 0.00130248ms, 1%
check_context::migraphx::gpu::context: 0.00127416ms / 1 = 0.00127416ms, 1%
broadcast: 0.00114032ms / 1 = 0.00114032ms, 1%
hip::hip_allocate_memory: 0.00094864ms / 1 = 0.00094864ms, 1%

Batch size: 1
Rate: 291.754 inferences/sec
Total time: 3.42755ms (Min: 3.41509ms, Max: 3.51073ms, Mean: 3.4358ms, Median: 3.42659ms)
Percentiles (90%, 95%, 99%): (3.47342ms, 3.48969ms, 3.50983ms)
Total instructions time: 3.54031ms
Overhead time: 0.00367538ms, -0.112759ms
Overhead: 0%, -3%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,dot \ 
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx

@TedThemistokleous (Collaborator, Author)

Need to get changes sorted, as attention isn't fusing anymore when using this PR in testing - #3993

We want this to be supported, as it reduces the number of larger dot/GEMMs and cuts our GEMM times from 98+ ms to ~75 ms.

Meaning this gives us about a 20-25% boost, should attention fuse correctly.

@TedThemistokleous (Collaborator, Author)

Got fusion sorted for this and benchmarked things using the customer script. Saw a 21% speedup using Paul's change.

Added the past/present inputs/outputs and unidirectional, leaving rotary input encoding out for now to just get this in. I can open an issue for this.

All I have left to do is the parser/verify tests. I've reduced the sizes for attention_double_head for verify-test simplicity (batch 2, seq 4, hidden 4, etc.). The math should work out to be the same once this is scaled up (a gen_onnx-style sketch of this reduced case follows these notes).

Use this test data for verification (WIP).
Simplify the tests to cover various cases; may or may not import more OnnxRT test cases for this.

Getting failures on this right now for the gold data; seeing inconsistent output.
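
For the reduced verify case, a gen_onnx.py-style sketch (node/tensor names and the mask_index shape are assumptions, not necessarily the test as committed) would look roughly like:

```python
import onnx
from onnx import TensorProto, helper

batch, seq, hidden, num_heads = 2, 4, 4, 2

# com.microsoft Attention node with input, fused QKV weights, bias, and mask_index.
node = helper.make_node(
    "Attention",
    inputs=["input", "weights", "bias", "mask_index"],
    outputs=["output"],
    num_heads=num_heads,
    domain="com.microsoft",
)

graph = helper.make_graph(
    [node],
    "attention_double_head_small",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("weights", TensorProto.FLOAT, [hidden, 3 * hidden]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3 * hidden]),
        helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch]),
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])],
)

onnx.save(helper.make_model(graph), "attention_double_head_small_test.onnx")
```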
@migraphx-bot (Collaborator)

| Test | Batch | Rate new (f6a702) | Rate old (fce74e) | Diff | Compare |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 3,227.29 | 3,237.29 | -0.31% | |
| torchvision-resnet50_fp16 | 64 | 6,887.57 | 6,875.20 | 0.18% | |
| torchvision-densenet121 | 32 | 2,447.27 | 2,444.78 | 0.10% | |
| torchvision-densenet121_fp16 | 32 | 4,205.70 | 4,185.93 | 0.47% | |
| torchvision-inceptionv3 | 32 | 1,618.97 | 1,617.74 | 0.08% | |
| torchvision-inceptionv3_fp16 | 32 | 2,705.95 | 2,707.57 | -0.06% | |
| cadene-inceptionv4 | 16 | 755.98 | 755.79 | 0.03% | |
| cadene-resnext64x4 | 16 | 815.39 | 813.93 | 0.18% | |
| slim-mobilenet | 64 | 7,440.21 | 7,439.91 | 0.00% | |
| slim-nasnetalarge | 64 | 209.91 | 208.65 | 0.60% | |
| slim-resnet50v2 | 64 | 3,337.12 | 3,332.41 | 0.14% | |
| bert-mrpc-onnx | 8 | 1,142.75 | 1,142.60 | 0.01% | |
| bert-mrpc-tf | 1 | 457.00 | 462.31 | -1.15% | |
| pytorch-examples-wlang-gru | 1 | 343.47 | 343.13 | 0.10% | |
| pytorch-examples-wlang-lstm | 1 | 478.85 | 486.93 | -1.66% | |
| torchvision-resnet50_1 | 1 | 819.45 | 803.37 | 2.00% | |
| cadene-dpn92_1 | 1 | 433.77 | 414.94 | 4.54% | 🔆 |
| cadene-resnext101_1 | 1 | 392.07 | 392.67 | -0.15% | |
| onnx-taau-downsample | 1 | 395.20 | 395.19 | 0.00% | |
| dlrm-criteoterabyte | 1 | 32.19 | 32.26 | -0.22% | |
| dlrm-criteoterabyte_fp16 | 1 | 51.12 | 51.26 | -0.27% | |
| agentmodel | 1 | 10,106.85 | 10,477.47 | -3.54% | 🔴 |
| unet_fp16 | 2 | 59.36 | 59.43 | -0.11% | |
| resnet50v1_fp16 | 1 | 1,085.50 | 1,041.11 | 4.26% | 🔆 |
| resnet50v1_int8 | 1 | 1,065.22 | 1,068.70 | -0.33% | |
| bert_base_cased_fp16 | 64 | 1,169.78 | 1,169.98 | -0.02% | |
| bert_large_uncased_fp16 | 32 | 356.42 | 356.28 | 0.04% | |
| bert_large_fp16 | 1 | 200.15 | 200.05 | 0.05% | |
| distilgpt2_fp16 | 16 | 2,229.61 | 2,229.84 | -0.01% | |
| yolov5s | 1 | 537.94 | 545.53 | -1.39% | |
| tinyllama | 1 | 43.69 | 43.60 | 0.20% | |
| vicuna-fastchat | 1 | 44.77 | 44.86 | -0.20% | |
| whisper-tiny-encoder | 1 | 419.55 | 418.31 | 0.30% | |
| whisper-tiny-decoder | 1 | 413.86 | 402.69 | 2.77% | |
| llama2_7b | 1 | 19.01 | 19.05 | -0.20% | |
| qwen1.5-7b | 1 | 23.45 | 23.43 | 0.06% | |
| phi3-3.8b | 1 | 26.53 | 26.54 | -0.00% | |
| mask-rcnn | 1 | 12.72 | 12.81 | -0.65% | |
| llama3-8b | 1 | 21.67 | 21.66 | 0.06% | |
| whisper-large-encoder | 1 | 10.18 | 10.18 | 0.04% | |
| whisper-large-decoder | 1 | 101.86 | 100.97 | 0.88% | |
| mistral-7b | 1 | 23.69 | 23.68 | 0.05% | |
| FLUX.1-schnell | 1 | 912.28 | 776.17 | 17.54% | 🔆 |
| nan | nan | nan | nan | nan% | |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

❌ bert-mrpc-tf: ERROR - check error output
2025-05-28 04:26:17.660151: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748424383.360979 184202 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 62973 MB memory: -> device: 0, name: AMD Instinct MI250X/MI250, pci bus id: 0000:32:00.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748424384.268987 184202 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2025-05-28 04:26:34.521466: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.521545: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.521591: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.522017: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.522045: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.522085: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.522125: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-05-28 04:26:34.522164: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
2025-05-28 04:26:34.523161: E tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc:228] INTERNAL: Generating device code failed.
2025-05-28 04:26:34.524697: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: JIT compilation failed.
2025-05-28 04:26:34.524719: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
2025-05-28 04:26:34.524731: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
2025-05-28 04:26:34.524748: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11217777527359497193
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1407, in _do_call
return fn(*args)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1390, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1483, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.UnknownError: 2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 324, in main
y_out = sess.run(y, feed_dict=tf_dict)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 977, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1220, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1400, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1426, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.UnknownError: Graph execution error:

Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.

Original stack trace for 'import/bert/embeddings/LayerNorm/moments/SquaredDifference':


     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

🔴unet: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

     ✅ llama2_7b: PASSED: MIGraphX meets tolerance

     ✅ qwen1.5-7b: PASSED: MIGraphX meets tolerance

     ✅ phi3-3.8b: PASSED: MIGraphX meets tolerance

🔴mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ llama3-8b: PASSED: MIGraphX meets tolerance

     ✅ whisper-large-decoder: PASSED: MIGraphX meets tolerance

     ✅ mistral-7b: PASSED: MIGraphX meets tolerance

     ✅ FLUX.1-schnell: PASSED: MIGraphX meets tolerance

Development

Successfully merging this pull request may close these issues.

Add Parser for Attention Contrib OP
5 participants