Forked from modelscope/ms-swift

Commit e9f4f9f (parent 0c97fb4): Support Sampler (Vanilla sampler) (modelscope#2905)
Showing 23 changed files with 1,076 additions and 5 deletions.
# Sampling

Sampling is one of SWIFT's newly supported key capabilities. It can be understood as a practical implementation of `test-time compute`, and it is also crucial for implementing RFT (Reinforcement Fine-Tuning).

## Capability Introduction

SWIFT's sampling capability can be demonstrated with the following example:

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```

A `jsonl` file named with a timestamp will be generated in the `sample_output` directory under the current folder. Since 5 dataset rows are each sampled with `--num_return_sequences 5`, the file should contain 25 lines, each a complete data sample in `messages` format.

For the full list of sampling parameters, refer to the [command-line parameters documentation](Command-line-parameters.md).
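As a hypothetical illustration of the output format described above, each line of the generated file is a JSON object carrying a `messages` list; the sample below is invented for demonstration, and the exact fields may vary across ms-swift versions:

```python
import json

# A hypothetical line of the generated jsonl file (not real output):
# a single complete conversation in `messages` format.
line = json.dumps({
    'messages': [
        {'role': 'user', 'content': 'Give three tips for staying healthy.'},
        {'role': 'assistant', 'content': '1. Eat a balanced diet. 2. Exercise. 3. Sleep well.'},
    ]
})

# Each line parses independently, so the file can be streamed line by line.
record = json.loads(line)
print(len(record['messages']))  # 2
```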
## Environment Setup

```shell
pip install ms-swift[llm] -U
```

Or install from source:

```shell
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e '.[llm]'
```
## Using PRM and ORM for Result Filtering

An important capability of sampling is supervising the process and the results, which is enabled through additional parameters:

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
```

A `jsonl` file named with a timestamp will be generated in the `sample_output` directory under the current folder. This file contains **at most** 10 lines, each a complete data sample in `messages` format.

> The file contains at most 10 lines because, although 5 data points are processed and 2 samples are kept per data point (`n_best_to_keep`), the ORM may fail to validate some samples, and failed samples are not written to the file.
>
> Additionally, when `--prm_model` or `--orm_model` is specified, the file format differs slightly: it includes a `rejected_response` key, whose content comes from the response with the lowest PRM score.
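The filtering described above can be sketched as follows. This is a simplified illustration, not the actual ms-swift implementation: the scores and pass/fail flags stand in for real PRM and ORM outputs.

```python
def filter_samples(candidates, prm_scores, orm_passed, n_best_to_keep):
    """Keep the n_best_to_keep highest-PRM candidates that pass the ORM,
    and record the lowest-PRM candidate as the rejected response."""
    # Drop candidates that fail ORM validation
    valid = [(c, s) for c, s, ok in zip(candidates, prm_scores, orm_passed) if ok]
    if not valid:
        return [], None  # nothing survives; this data point is skipped
    # Sort surviving candidates by PRM score, best first
    valid.sort(key=lambda pair: pair[1], reverse=True)
    kept = [c for c, _ in valid[:n_best_to_keep]]
    # rejected_response comes from the lowest-scored candidate overall
    rejected = min(zip(candidates, prm_scores), key=lambda pair: pair[1])[0]
    return kept, rejected


# 5 sampled responses for one prompt, with stand-in PRM scores and ORM results
kept, rejected = filter_samples(
    ['r1', 'r2', 'r3', 'r4', 'r5'],
    [0.9, 0.2, 0.7, 0.4, 0.8],
    [True, False, True, True, True],
    n_best_to_keep=2)
print(kept, rejected)  # ['r1', 'r5'] r2
```

Because a data point whose candidates all fail the ORM contributes nothing, the output file can hold fewer than `5 * n_best_to_keep` lines.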

## Customizing PRM or ORM

A custom PRM or ORM can be added in the plugin module, following the existing implementations. For example:

```python
from typing import List

import torch

# Import paths may vary across ms-swift versions
from swift.llm import ChatCompletionResponse, InferRequest


class CustomPRM:

    # The constructor must take no arguments
    def __init__(self):
        # Initialize the reward model here
        pass

    @torch.inference_mode()
    def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]:
        ...


prms = {'custom': CustomPRM}
```

Afterward, pass `--prm_model custom` on the command line.
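As a hypothetical illustration of what a rule-based ORM such as `math` might do internally, the sketch below validates a response by comparing its final boxed answer against a reference. The class name and checking logic are illustrative, not the actual ms-swift API:

```python
import re


class SimpleMathORM:
    """Illustrative outcome reward model: a response passes only if its
    last \\boxed{...} answer matches the ground truth exactly."""

    def __init__(self):
        self._pattern = re.compile(r'\\boxed\{([^}]*)\}')

    def check(self, response: str, ground_truth: str) -> bool:
        matches = self._pattern.findall(response)
        if not matches:
            return False  # no final answer to validate
        return matches[-1].strip() == ground_truth.strip()


orm = SimpleMathORM()
print(orm.check(r'... so the answer is \boxed{42}', '42'))  # True
print(orm.check(r'... so the answer is \boxed{41}', '42'))  # False
```

A real ORM would normalize equivalent expressions (fractions, units, whitespace) before comparing, rather than requiring an exact string match.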
## Practical Example

Refer to the [reinforcement fine-tuning script](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py), which provides a practical example of using sampling for reinforcement fine-tuning.

> **Note:** The actual effectiveness of this script depends strongly on the quality of the model, the data, and the RM, so it is provided only as an example. Users should adapt the script and train their own RM and generator models.
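At a high level, a reinforcement fine-tuning loop combines sampling and filtering roughly as follows. This is a schematic sketch with stand-in functions, not the actual `rft.py` code:

```python
def rft_loop(prompts, sample, filter_best, fine_tune, num_iterations=3):
    """Schematic RFT loop: sample candidates for each prompt, keep the
    best ones (e.g. via PRM/ORM filtering), fine-tune on them, repeat."""
    dataset = []
    for _ in range(num_iterations):
        for prompt in prompts:
            candidates = sample(prompt)      # e.g. `swift sample`
            kept = filter_best(candidates)   # PRM/ORM result filtering
            dataset.extend((prompt, c) for c in kept)
        fine_tune(dataset)                   # e.g. `swift sft` on the kept samples
    return dataset


# Tiny runnable demonstration with stand-in functions
data = rft_loop(
    ['q1'],
    sample=lambda p: ['a', 'b'],
    filter_best=lambda cs: cs[:1],
    fine_tune=lambda ds: None,
    num_iterations=2)
print(len(data))  # 2
```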