
[LLM] Support block_attention/cachekv quant for llama #7649

Merged
merged 15 commits into PaddlePaddle:develop
Jan 10, 2024

Conversation

RichardWooSJTU
Contributor

@RichardWooSJTU commented Dec 14, 2023

PR types

New features

PR changes

Others

Description

  1. Support block attention. Disabled by default; enable it by setting --block_attn.
  2. Support cache KV quantization. Disabled by default; set --use_cachekv_int8 static for static quantization or --use_cachekv_int8 dynamic for dynamic quantization. A usage sketch follows this list.
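For reference, a minimal usage sketch assembled from the flags quoted elsewhere in this PR; the model path and dtype are the illustrative values from the quoted benchmark.sh, and whether every flag combination is valid is not confirmed here:

python predictor.py \
    --model_name_or_path ./llama7b-inference_model_fp16 \
    --dtype float16 \
    --inference_model \
    --block_attn \
    --use_cachekv_int8 static

Replacing static with dynamic would select dynamic cache KV quantization instead.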


paddle-bot bot commented Dec 14, 2023

Thanks for your contribution!


CLAassistant commented Dec 14, 2023

CLA assistant check
All committers have signed the CLA.


codecov bot commented Dec 14, 2023

Codecov Report

Attention: 422 lines in your changes are missing coverage. Please review.

Comparison is base (dab175b) 57.12% compared to head (8b91dc8) 56.95%.
Report is 5 commits behind head on develop.

Files Patch % Lines
...dlenlp/experimental/transformers/llama/modeling.py 0.00% 209 Missing ⚠️
...enlp/experimental/transformers/generation_utils.py 0.00% 100 Missing ⚠️
...erimental/transformers/fused_transformer_layers.py 0.00% 99 Missing ⚠️
paddlenlp/experimental/model_utils.py 12.50% 14 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7649      +/-   ##
===========================================
- Coverage    57.12%   56.95%   -0.17%     
===========================================
  Files          587      587              
  Lines        88190    88626     +436     
===========================================
+ Hits         50376    50479     +103     
- Misses       37814    38147     +333     


llm/export.sh Outdated



python -m paddle.distributed.launch \
Collaborator

This script can go under the llama model directory for now.

@@ -26,6 +26,7 @@
import paddle
Collaborator

As discussed, move it to the llama directory first; once all models are covered, migrate it later.

llm/read_res.py Outdated
@@ -0,0 +1,49 @@
import paddle
Collaborator

Same as above.

# --batch_size 2 \
# --inference_model \
# --quant_type ${quant_type} \
# --block_attn \
Collaborator

Same as above.

Contributor

This part of the script can be moved into the inference.md documentation.

--block_attn \
--inference_model \
--use_cachekv_int8 static

Collaborator

Same as above.

temperature,
model_kwargs,
)

Collaborator

Please also add PaddleNLP CI and Paddle CI coverage.

}
}

// Restore state for the recoverable query_id computed in the previous step
Contributor

Please switch this to an English comment.

llm/.gitignore Outdated
@@ -1,3 +1,6 @@

max_len.txt
Contributor

This is a configuration file from your local test environment; please revert this file.

llm/benchmark.sh Outdated
@@ -27,10 +27,10 @@ export FLAGS_cache_inference_while_scope=1
python predictor.py \
--model_name_or_path ./llama7b-inference_model_fp16 \
--dtype float16 \
--src_length 300 \
--max_length 100 \
--src_length ${total_len} \
Contributor

  1. Semantically, total_len should be src_length + max_length, so shouldn't the variable names be kept consistent with the corresponding --name arguments?
  2. Please give src_length and max_length default values, e.g. src_length=${src_length:-300}; see the sketch below.
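A sketch of the suggested defaults, assuming the variable names from this comment and the quoted benchmark.sh (final values are up to the author):

src_length=${src_length:-300}
max_length=${max_length:-100}
# total_len is the prompt length plus the generation length
total_len=$((src_length + max_length))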

llm/export.sh Outdated
--src_length ${total_len} \
--block_attn \
--quant_type ${quant_type} \
--use_cachekv_int8 static
Contributor

My suggestion is to move this script into the ./inference.md documentation and add a new cachekv section there to describe it.

llm/predictor.py Outdated
@@ -703,6 +723,526 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
return None


class DygraphBlockInferencePredictor(BasePredictor):
Contributor

This class should probably inherit from InferencePredictorMixin, which is the mixin used for handling inference models.

Comment on lines +760 to +766
self.pre_caches = [
paddle.zeros(
[config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim],
dtype=self.dtype,
)
for _ in range(2 * self.num_layers)
]
Contributor

Has the pre_cache + cache_kv-int8 combination been tested here? If so, could you add a corresponding unit test?

# --batch_size 2 \
# --inference_model \
# --quant_type ${quant_type} \
# --block_attn \
Contributor

This part of the script can be moved into the inference.md documentation.

Comment on lines 410 to 411
print("scale_type: ", scale_type)
print("key_template: ", key_template)
Contributor

This is a base class, so the print statements here should be removed.

Comment on lines +746 to +751
def post_process(self, **kwargs):
time_step = kwargs.get("time_step", None)
multi_block_output = kwargs.get("multi_block_output", None)
cum_offsets = kwargs.get("cum_offsets", None)
seq_lens = kwargs.get("seq_lens", None)
input_ids = kwargs.get("input_ids", None)
Contributor

All of this code needs unit tests.

Comment on lines 839 to 840
# print("out_linear_out", out_linear_out)
# exit(0)
Contributor

Remove these.

else:
precache_kv_spec = None
use_cachekv_int8 = config.get("use_cachekv_int8", "None")
print("use_cachekv_int8", use_cachekv_int8)
Contributor

I won't point out the other print statements one by one; please clean those up as well.

@wawltor wawltor merged commit c5d8d5b into PaddlePaddle:develop Jan 10, 2024
7 of 9 checks passed