[LLM] Support block_attention/cachekv quant for llama #7649
Conversation
Thanks for your contribution!
Codecov Report

Attention:

Additional details and impacted files:

@@            Coverage Diff             @@
##           develop    #7649     +/-   ##
===========================================
- Coverage    57.12%   56.95%    -0.17%
===========================================
  Files          587      587
  Lines        88190    88626     +436
===========================================
+ Hits         50376    50479     +103
- Misses       37814    38147     +333

☔ View full report in Codecov by Sentry.
Force-pushed from c27663e to c43c61a
llm/export.sh
Outdated
python -m paddle.distributed.launch \
This script could first be placed under the llama model directory.
@@ -26,6 +26,7 @@
import paddle
As discussed, move this to the llama directory first; once all models are covered, it can be migrated back.
llm/read_res.py
Outdated
@@ -0,0 +1,49 @@
import paddle
Same as above.
llm/run_dygraph.sh
Outdated
# --batch_size 2 \
# --inference_model \
# --quant_type ${quant_type} \
# --block_attn \
Same as above.
This script can be migrated into the inference.md docs.
llm/run_static.sh
Outdated
--block_attn \
--inference_model \
--use_cachekv_int8 static
Same as above.
temperature,
model_kwargs,
)
Please also add coverage for this in PaddleNLP CI and Paddle CI.
}
}

// 根据上一步计算出的可以复原的query_id进行状态恢复
Please switch this to an English comment, e.g. "// Restore state for the recoverable query_ids computed in the previous step".
llm/.gitignore
Outdated
@@ -1,3 +1,6 @@
max_len.txt
This looks like a config file from your local test environment; please revert this file.
llm/benchmark.sh
Outdated
@@ -27,10 +27,10 @@ export FLAGS_cache_inference_while_scope=1
python predictor.py \
--model_name_or_path ./llama7b-inference_model_fp16 \
--dtype float16 \
--src_length 300 \
--max_length 100 \
--src_length ${total_len} \
- Semantically, total_len should be src_length + max_length, so shouldn't these two variable names stay consistent with the corresponding --name flags?
- Give src_length and max_length default values, e.g. src_length=${src_length:-300}.
llm/export.sh
Outdated
--src_length ${total_len} \ | ||
--block_attn \ | ||
--quant_type ${quant_type} \ | ||
--use_cachekv_int8 static |
My suggestion is to migrate this script into the ./inference.md docs and open a new cachekv section there to describe it.
llm/predictor.py
Outdated
@@ -703,6 +723,526 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
return None

class DygraphBlockInferencePredictor(BasePredictor):
This module should probably inherit from InferencePredictorMixin; that mixin is what handles inference models.
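A minimal sketch of the suggested refactor, assuming InferencePredictorMixin is the mixin already defined in llm/predictor.py and that its constructor takes (config, tokenizer); the exact signatures here are assumptions:

```python
# Sketch only: BasePredictor and InferencePredictorMixin are the classes
# already defined in llm/predictor.py; constructor signatures are assumed.
class DygraphBlockInferencePredictor(BasePredictor, InferencePredictorMixin):
    def __init__(self, config, tokenizer=None, model=None):
        BasePredictor.__init__(self, config, tokenizer)
        # Reuse the mixin's inference-model input preparation instead of
        # re-implementing it in this class.
        InferencePredictorMixin.__init__(self, config, tokenizer)
        self.model = model
```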
self.pre_caches = [
    paddle.zeros(
        [config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim],
        dtype=self.dtype,
    )
    for _ in range(2 * self.num_layers)
]
Have you tested the pre_cache + cache_kv-int8 combination here? If so, could you add a corresponding unit test?
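A minimal sketch of the kind of unit test being requested, checking only the pre-cache layout from the snippet above; the test name and constants are hypothetical, and a real test should also build the predictor with use_cachekv_int8 enabled:

```python
import unittest

import paddle


class PreCacheCacheKVInt8Test(unittest.TestCase):
    def test_pre_cache_shapes(self):
        batch_size, num_heads, pre_cache_len, head_dim, num_layers = 1, 8, 16, 64, 2
        # Mirrors the pre_caches construction above: one key cache and one
        # value cache per layer.
        pre_caches = [
            paddle.zeros([batch_size, num_heads, pre_cache_len, head_dim], dtype="float16")
            for _ in range(2 * num_layers)
        ]
        self.assertEqual(len(pre_caches), 2 * num_layers)
        self.assertEqual(pre_caches[0].shape, [batch_size, num_heads, pre_cache_len, head_dim])


if __name__ == "__main__":
    unittest.main()
```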
llm/run_dygraph.sh
Outdated
# --batch_size 2 \
# --inference_model \
# --quant_type ${quant_type} \
# --block_attn \
This script can be migrated into the inference.md docs.
print("scale_type: ", scale_type) | ||
print("key_template: ", key_template) |
This is a base class, so these print calls should be removed.
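One way to keep the diagnostics without printing from a base class, assuming the intent is debug-level logging; PaddleNLP's shared logger from paddlenlp.utils.log is used here, and the wrapper function is hypothetical:

```python
from paddlenlp.utils.log import logger


def log_scale_info(scale_type: str, key_template: str) -> None:
    # Debug-level logging keeps the information available when needed
    # without writing to stdout for every user of the base class.
    logger.debug(f"scale_type: {scale_type}")
    logger.debug(f"key_template: {key_template}")
```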
def post_process(self, **kwargs):
    time_step = kwargs.get("time_step", None)
    multi_block_output = kwargs.get("multi_block_output", None)
    cum_offsets = kwargs.get("cum_offsets", None)
    seq_lens = kwargs.get("seq_lens", None)
    input_ids = kwargs.get("input_ids", None)
All of this code needs unit tests.
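A skeleton of such a test; post_process's contract is not visible in this diff, so the stub below only exercises the kwargs plumbing and should be replaced by the real predictor class when wiring this into CI:

```python
import unittest


class _PredictorStub:
    # Hypothetical stand-in for the predictor under review.
    def post_process(self, **kwargs):
        keys = ("time_step", "multi_block_output", "cum_offsets", "seq_lens", "input_ids")
        return {k: kwargs.get(k) for k in keys}


class PostProcessKwargsTest(unittest.TestCase):
    def test_expected_kwargs_are_consumed(self):
        out = _PredictorStub().post_process(time_step=0, seq_lens=[16])
        self.assertEqual(out["time_step"], 0)
        self.assertIsNone(out["input_ids"])


if __name__ == "__main__":
    unittest.main()
```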
# print("out_linear_out", out_linear_out) | ||
# exit(0) |
Please remove these debug lines.
else:
    precache_kv_spec = None
use_cachekv_int8 = config.get("use_cachekv_int8", "None")
print("use_cachekv_int8", use_cachekv_int8)
I won't point out the other print statements one by one; please clean them all up as well.
PR types
New features

PR changes
Others

Description
Support block attention and cache-KV quantization for llama inference: enable block attention by passing block_attn; enable static cache-KV int8 quantization by setting use_cachekv_int8 to static, and dynamic quantization by setting use_cachekv_int8 to dynamic.
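As a hedged illustration of the two modes, the fields below mirror this PR's flags; the actual argument dataclass in llm/predictor.py may use different names or defaults:

```python
from dataclasses import dataclass


@dataclass
class PredictorConfig:
    # Hypothetical mirror of the CLI flags added in this PR.
    block_attn: bool = False        # enable block attention
    use_cachekv_int8: str = "None"  # "static", "dynamic", or "None" (disabled)


static_quant = PredictorConfig(block_attn=True, use_cachekv_int8="static")    # static cache-KV int8 quant
dynamic_quant = PredictorConfig(block_attn=True, use_cachekv_int8="dynamic")  # dynamic cache-KV int8 quant
```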