-
Notifications
You must be signed in to change notification settings - Fork 4
Support Mha dpas shader #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
tools/cross_runner/src/gemm.h
Outdated
@@ -19,8 +19,11 @@ enum class GemmType | |||
// q + kv | |||
GemmType_QK_Q_KV, | |||
GemmType_SV_S_KV, | |||
//dpas | |||
GemmType_QK_QKV_DPAS, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of new GemmType type I recommend to add cmd argument "--use_dpas"
to GemmCmDispatcher::cm_params_t class. This way we will avoid a lot of if conditions like for example else if (type_ == GemmType::GemmType_QK_QKV || type_ == GemmType::GemmType_QK_QKV_DPAS )
and we can use have very little if conditions with "--use_dpas"
in kernel/build options selection mechanism
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of new GemmType type I recommend to add cmd argument
"--use_dpas"
to GemmCmDispatcher::cm_params_t class. This way we will avoid a lot of if conditions like for exampleelse if (type_ == GemmType::GemmType_QK_QKV || type_ == GemmType::GemmType_QK_QKV_DPAS )
and we can use have very little if conditions with"--use_dpas"
in kernel/build options selection mechanism
Based on your suggestion, I may need to change a logic for selecting CM shader. Could you help check whether it may work with the change on the right side.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that looks good!
gws_z = get_batch() * get_channels() * cm_params_.slice_k; | ||
if(params_.use_dpas) | ||
{ | ||
gws_x = 8; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gws sizes hardcoded? what if M/K/N dimensions changes via cmd line args?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Prithviraj-R Could you update the calculation here once you figure it out? Thanks!
Currently GEMM only works for QK_QKV mode which is enough for dev need