Skip to content

Support Mha dpas shader #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Conversation

Kotomi-Du
Copy link
Collaborator

Currently GEMM only works for QK_QKV mode which is enough for dev need

@@ -19,8 +19,11 @@ enum class GemmType
// q + kv
GemmType_QK_Q_KV,
GemmType_SV_S_KV,
//dpas
GemmType_QK_QKV_DPAS,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of new GemmType type I recommend to add cmd argument "--use_dpas" to GemmCmDispatcher::cm_params_t class. This way we will avoid a lot of if conditions like for example else if (type_ == GemmType::GemmType_QK_QKV || type_ == GemmType::GemmType_QK_QKV_DPAS ) and we can use have very little if conditions with "--use_dpas" in kernel/build options selection mechanism

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of new GemmType type I recommend to add cmd argument "--use_dpas" to GemmCmDispatcher::cm_params_t class. This way we will avoid a lot of if conditions like for example else if (type_ == GemmType::GemmType_QK_QKV || type_ == GemmType::GemmType_QK_QKV_DPAS ) and we can use have very little if conditions with "--use_dpas" in kernel/build options selection mechanism

Based on your suggestion, I may need to change a logic for selecting CM shader. Could you help check whether it may work with the change on the right side.

image

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that looks good!

gws_z = get_batch() * get_channels() * cm_params_.slice_k;
if(params_.use_dpas)
{
gws_x = 8;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gws sizes hardcoded? what if M/K/N dimensions changes via cmd line args?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Prithviraj-R Could you update the calculation here once you figure it out? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants