Skip to content

Conversation

@xiaoxiaohehe001
Copy link
Contributor

@xiaoxiaohehe001 xiaoxiaohehe001 commented Aug 17, 2023

PR types

Others

PR changes

Others

Description

Add bias input of mmha and simplify mmha.
关联pr #55344
Pcard-71502

@paddle-bot
Copy link

paddle-bot bot commented Aug 17, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

out->set_dims(out_dims);
out->set_dtype(x.dtype());
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

量化的完整需求是
输入为int32/float16/32
输出为int8/float/16/32

我看这里考虑了输出是int8的情况 但是没考虑输入是int32的情况是么


namespace phi {
namespace fusion {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥要把头文件干掉呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要添加多余的头文件,防止被其他调用

Args:
x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

也需要加一下compte_dtype的参数说明~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done~

self.rotary_emb_dims = 0
self.use_neox_rotary_style = False

self.compute_dtype = "default"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测要不加一下输入为Int32的情形

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测有int32的情况

seq_len (int, optional): The seq_len, used to get input length. Default 1.
rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
compute_dtype (string): A compute dtype, used to represent the input data type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compute_dtype的为啥不能根据输入tensor的类型判断呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ptq 情况下,输入 x 有可能是int32,如果根据 cache_kv dtype 判断,后续 cache_kv 量化支持还需要修改。

Copy link
Contributor

@vivienfanghuagood vivienfanghuagood left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for API change

Copy link
Contributor

@lanxianghit lanxianghit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for new args

@heavengate heavengate merged commit 636dc2f into PaddlePaddle:develop Aug 25, 2023
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants