[DCU] Llama a8w8 inference performance optimization #8800
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff            @@
##           develop    #8800     +/-  ##
===========================================
+ Coverage    55.37%   55.58%    +0.21%
===========================================
  Files          631      630        -1
  Lines        99707    98382     -1325
===========================================
- Hits         55211    54687      -524
+ Misses       44496    43695      -801

View full report in Codecov by Sentry.
@@ -1384,7 +1392,10 @@ def compute_mmha(self, qkv_out, caches, attn_mask, seq_lens, rotary_embs, rotary
     )[0]

 def compute_out_linear(self, fmha_out, i):
     out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
     if paddle.is_compiled_with_rocm():
Please explain in the PR description why the transpose is not needed on ROCm.
The explanation has been added.
LGTM
PR types
Performance optimization
PR changes
Models
Description
Optimize the inference performance in Llama a8w8 case.
On the DCU platform, rocBLAS GEMM performance across transpose layouts ranks NT > NN > TN. Because Paddle's matmul cannot trigger the NT layout in this scenario, the suboptimal NN layout is chosen instead.
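The layout switch can be sketched as follows. This is a hypothetical illustration in NumPy (function and variable names are assumptions, not from the PR): an NT GEMM computes x @ W.T with the weight stored as [out_dim, in_dim], while the NN path stores the weight pre-transposed once at load time as [in_dim, out_dim] and then runs a plain GEMM. Both paths produce identical results; only the memory layout handed to the GEMM changes.

```python
import numpy as np

def out_linear_nt(fmha_out, weight):
    # NT path: weight stored as [out_dim, in_dim]; the GEMM transposes it
    # (analogous to paddle.matmul(x, w, transpose_x=False, transpose_y=True))
    return fmha_out @ weight.T

def out_linear_nn(fmha_out, weight_pre_transposed):
    # NN path for ROCm/DCU: weight already stored as [in_dim, out_dim],
    # so the GEMM runs without any transpose flag
    return fmha_out @ weight_pre_transposed

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))            # [batch, in_dim]
w = rng.standard_normal((16, 8))           # [out_dim, in_dim]
w_nn = np.ascontiguousarray(w.T)           # pre-transpose once at load time

# Both layouts yield the same output; the choice only affects GEMM speed.
assert np.allclose(out_linear_nt(x, w), out_linear_nn(x, w_nn))
```

Pre-transposing at weight-load time keeps the per-step inference cost identical to the NT path while letting rocBLAS run the faster NN kernel.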