Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FasterCodeGen/FasterGPTJ #3017

Merged
merged 28 commits into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 222 additions & 57 deletions examples/code_generation/codegen/README.md
Original file line number Diff line number Diff line change
@@ -1,74 +1,219 @@
# CodeGen: A Conversational Paradigm for Program Synthesis
# 代码生成:写代码的AI助理

## 模型简介
**目录**
- [代码生成](#代码生成)
- [简介](#简介)
- [特色](#特色)
- [效果展示](#效果展示)
- [开箱即用](#开箱即用)
- [支持单条、批量预测](#支持单条批量预测)
- [可配置参数说明](#可配置参数说明)
- [训练定制](#训练定制)
- [环境依赖](#环境依赖)
- [代码结构说明](#代码结构说明)
- [数据准备](#数据准备)
- [从本地文件创建数据集](#从本地文件创建数据集)
- [Github Copilot插件配置](#GithubCopilot插件配置)
- [插件环境依赖](#插件环境依赖)
- [启动服务](#启动服务)
- [配置参数](#配置参数说明)
- [测试服务](#测试服务)
- [配置插件](#配置插件)
- [注意事项](#注意事项)
- [TaskFlow调用](#TaskFlow调用)
- [使用案例](#使用案例)
- [模型列表](#模型列表)
- [References](#references)

[CodeGen](https://arxiv.org/pdf/2203.13474.pdf) (A Conversational Paradigm for Program Synthesis)提出了一种通过大型语言模型进行对话式程序生成的方法,将编写规范和程序的过程转换为用户和系统之间的多回合对话。它把程序生成看作一个序列预测问题,用自然语言表达规范,并有条件地对所期望的程序进行抽样。同时,CodeGen(16B)在HumanEval benchmark上已经超过[OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf)。

本项目展示如何调用CodeGen来进行代码生成。
## 简介
代码生成是根据编程人员的输入,生成出编程人员想要的代码,能够帮助编程人员甚至独立生成代码,提高编程效率。

## 快速开始

### 特色

本项目是基于预训练语言模型CodeGen的代码生成,具有以下优势:
- **效果领先**。CodeGen(16B)在HumanEval benchmark上评估指标已经超过[OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf)。
- **免费的Github Copilot**。支持通过Github Copilot调用该模型,让你免费体验代码AI助理。
- **高性能**。基于[FasterGeneration](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation)打造高性能推理,毫秒级响应。具体加速指标可参考[perf](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation/README.md)。
- **支持自定义数据集训练**。可增加自己的代码数据加以微调,让其更智能。
- **开箱即用**。本项目提供TaskFlow接口,无需训练,仅需几行代码便可预测。


## 效果展示

## 训练定制

### 环境依赖
- PaddleNLP >= 2.4.0
- PaddlePaddle >= 2.3.1

### 代码结构说明

以下是本项目主要代码结构及说明:

```text
codegen/
├── requirements.txt # 环境依赖
├── codegen_server.py # server启动脚本
├── run_clm.py # 训练评估脚本
├── run_clm.sh # 启动脚本
└── README.md # 说明文档
```

### 数据准备

#### 从本地文件创建数据集

在许多情况,我们需要使用本地数据集来训练我们的代码生成模型,本项目支持使用固定格式本地数据集文件进行训练。

本地数据集文件格式如下:
- train.json/test.json 文件格式:
每行为一个jsonline
```text
{
"code": "from paddlenlp.transformers import CodeGenForCausalLM\n\n\nmodel = CodeGenForCausalLM.from_pretrained('Salesforce/codegen-2B-mono')\n"
}
```

更多数据集读取格式详见[数据集加载](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_load.html#)和[自定义数据集](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html)。


### 模型训练
运行如下命令即可在样例训练集上进行finetune,并在样例验证集上进行验证。

```shell
# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡
unset CUDA_VISIBLE_DEVICES

python -m paddle.distributed.launch --gpus 0,1 run_clm.py \
--model_name_or_path Salesforce/codegen-350M-mono \
--block_size 1024 \
--output_dir output \
--train_file train.json \
--validation_file test.json \
--num_train_epochs 5 \
--logging_steps 1 \
--save_steps 10 \
--train_batch_size 2 \
--eval_batch_size 2 \
--learning_rate 1e-4 \
--warmup_proportion 0.1 \
--device gpu
```
使用多卡训练可以指定多个GPU卡号,例如 --gpus "0,1"

关键参数释义如下:
- `gpus` 指示了训练所用的GPU卡号。
- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型(详见[模型列表](#模型列表)),或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。
- `block_size` 表示训练时候数据被拆分的块数。
- `output_dir` 表示模型的保存路径。
- `train_file` 本地训练数据地址,数据格式必须与`dataset_name`所指数据集格式相同。
- `validation_file` 本地测试数据地址,数据格式必须与`dataset_name`所指数据集格式相同。
- `num_train_epochs` 表示训练轮数。
- `logging_steps` 表示日志打印间隔。
- `save_steps` 表示模型保存及评估间隔。
- `train_batch_size` 表示训练时**每张卡**上的样本数目。
- `eval_batch_size` 表示测试时**每张卡**上的样本数目。
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
- `warmup_propotion` 表示学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数占总步数的比例,最早的使用可以参考[这篇论文](https://arxiv.org/pdf/1706.02677.pdf)。
- `device` 表示使用的设备,从gpu和cpu中选择。

- python >= 3.6
- paddlepaddle >= 2.3.0
- paddlenlp >= 2.3.4
可通过`bash run_clm.sh`启动训练,更多参数详情和参数的默认值请参考`run_clm.py`。

### 代码调用
程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`save_dir`中。
如:
```text
./output/
│── model_config.json
│── model_state.pdparams
│── tokenizer_config.json
│── special_tokens_map.json
│── added_tokens.json
│── vocab.json
│── merges.txt
└── ...
```

**NOTE:** 如需恢复模型训练,`model_name_or_path`配置本地模型的目录地址即可。

## GithubCopilot插件配置
以下以VS Code的插件为例
### 插件环境依赖
- PaddleNLP >= 2.4.0
- PaddlePaddle >= 2.3.1

其他依赖:`pip install -r requirements.txt`


### 启动服务

```python
import re
import paddle
from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
python codegen_server.py
```

##### 配置参数说明
在codegen_server.py中配置如下参数:
- `model_name_or_path`:模型名,默认为 "Salesforce/codegen-2B-mono"
- `device`:运行设备,默认为"gpu"
- `temperature`:解码参数temperature,默认为0.5
- `top_k`:解码参数top_k,默认为10
- `top_p`:解码参数top_p,默认为1.0
- `repetition_penalty`:解码重复惩罚项,默认为1.0
- `min_length`:生成的最小长度,默认为0
- `max_length`:生成的最大长度,默认为16
- `decode_strategy`:解码策略,默认为"sampling"
- `load_state_as_np`:以numpy格式加载模型参数,可节省显存,默认为True
- `use_faster`:是否使用Fastergeneration,可加速推理,默认为True
- `use_fp16_decoding`:是否使用fp16推理,可节省显存和加速推理,默认为True

### 测试服务
`pip install --upgrade openai`

```python
import openai
openai.api_key = 'dummy'
openai.api_base = 'http://127.0.0.1:8000/v1'
result = openai.Completion.create(
engine='codegen', prompt='def hello', max_tokens=16, temperature=0.1)
print(result)
'''
<OpenAIObject text_completion id=cmpl-dmhoeHmcw9DJ4NeqOJDQVKv3iivJ0 at 0x7fe7a81d42c0> JSON: {
"id": "cmpl-dmhoeHmcw9DJ4NeqOJDQVKv3iivJ0",
"choices": [
{
"text": "_world():\n print(\"Hello World!\")\n\n\n#",
"index": 0,
"finish_reason": "stop",
"logprobs": null,
}
],
"usage": {
"completion_tokens": null,
"prompt_tokens": null,
"total_tokens": null
}
}
'''

# The supported models are shown in the following table
model_name = 'Salesforce/codegen-350M-mono'
# Init tokenizer
tokenizer = CodeGenTokenizer.from_pretrained(model_name)
# Init model
model = CodeGenForCausalLM.from_pretrained(model_name)
inputs = tokenizer(["def hello_world():"])
inputs = {k: paddle.to_tensor(v) for (k, v) in inputs.items()}
# Generate
output, score = model.generate(inputs['input_ids'],
max_length=128,
decode_strategy='sampling',
top_k=5,
repetition_penalty=1.1,
temperature=0.6)
# Decode the result
print(
re.split(
"\nclass|\ndef|\n#|\n@|\nprint|\nif",
tokenizer.decode(output[0],
skip_special_tokens=True,
spaces_between_special_tokens=False))[0].rstrip())
```

其中参数释义如下:
- `max_length` 解码的最大长度,默认128。
- `decode_strategy` 解码的策略,默认sampling。
- `top_k` 解码参数top_k,默认5。
- `repetition_penalty` 解码重复惩罚系数,默认1.1。
- `temperature` 解码参数temperature,默认0.6。
### 配置插件
打开用户设置([settings.json](https://code.visualstudio.com/docs/getstarted/settings#_settings-file-locations)),增加一行配置
```json
"github.copilot.advanced": {
"debug.overrideEngine": "codegen",
"debug.testOverrideProxyUrl": "http://127.0.0.1:8978",
"debug.overrideProxyUrl": "http://127.0.0.1:8978"
},
```

模型列表
| 模型名称 | 说明 |
| :--------------------------------- | -------------------------------- |
| Salesforce/codegen-350M-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-2B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-6B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-16B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-350M-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-2B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-6B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-16B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-350M-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-2B-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-6B-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-16B-multi | 基于多编程语言数据集BIGQUERY训练 |
接下来就可以愉快地使用了😊。
#### 注意事项
- 如果使用FasterGeneration,需要设置[codegen_server.py](#配置参数说明)中`use_faster=True`,第一次推理会涉及到编译,会耗费一些时间。FasterGeneration的环境依赖参考[这里](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/ops/README.md#%E4%BD%BF%E7%94%A8%E7%8E%AF%E5%A2%83%E8%AF%B4%E6%98%8E)。
- 如果要使用自己训练好的模型,可以设置[codegen_server.py](#配置参数说明)中`model_name_or_path`为本地模型路径。

### TaskFlow调用
## TaskFlow调用
参考[TaskFlow文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/model_zoo/taskflow.md)

## 使用案例
Expand Down Expand Up @@ -157,4 +302,24 @@ def hello_world():

hello_world()
```
其它更多趣味性的生成欢迎大家体验,同时也欢迎大家来开发代码补全的插件。

## 模型列表
模型列表
| 模型名称 | 说明 |
| :--------------------------------- | -------------------------------- |
| Salesforce/codegen-350M-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-2B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-6B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-16B-mono | 基于Python数据集BIGPYTHON训练 |
| Salesforce/codegen-350M-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-2B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-6B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-16B-nl | 基于自然语言数据集THEPILE训练 |
| Salesforce/codegen-350M-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-2B-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-6B-multi | 基于多编程语言数据集BIGQUERY训练 |
| Salesforce/codegen-16B-multi | 基于多编程语言数据集BIGQUERY训练 |

## References
- Nijkamp, Erik, et al. "A conversational paradigm for program synthesis." arXiv preprint arXiv:2203.13474 (2022).
- [https://github.com/features/copilot/](https://github.com/features/copilot/)
Loading