Skip to content

Commit

Permalink
Support Sampler(Vanilla sampler) (modelscope#2905)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet authored Jan 13, 2025
1 parent 0c97fb4 commit e9f4f9f
Show file tree
Hide file tree
Showing 23 changed files with 1,076 additions and 5 deletions.
26 changes: 26 additions & 0 deletions docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,32 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数)
- hub_private_repo: 是否是private repo,默认为False
- commit_message: 提交信息,默认为'update files'

### 采样参数

- prm_model: 过程奖励模型的类型,可以是模型id(以pt方式拉起),或者plugin中定义的prm key(自定义推理过程)
- orm_model: 结果奖励模型的类型,通常是通配符或测试用例等,一般定义在plugin中

- sampler_type:采样类型,目前支持sample(do_sample方式),未来会支持mcts和dvts
- sampler_engine:支持`pt`, `lmdeploy`, `vllm`, `no`,默认为`pt`,采样模型的推理引擎
- output_dir:输出目录,默认为`sample_output`
- output_file:输出文件名称,默认为`None`使用时间戳作为文件名。传入时不需要传入目录,仅支持jsonl格式
- override_exist_file:如`output_file`存在,是否覆盖
- num_sampling_per_gpu_batch_size:每次采样的batch_size
- num_sampling_per_gpu_batches:共采样多少batch
- n_best_to_keep:返回多少最佳sequences
- data_range:本采样处理数据集的分片。传入格式为`2 3`,代表数据集分为3份处理(这意味着通常有三个`swift sample`在并行处理),本实例正在处理第3个分片

- temperature:在这里默认为1.0
- prm_threshold:PRM阈值,低于该阈值的结果会被过滤掉,默认值为`0`
- easy_query_threshold:单个query的所有采样中,ORM评估如果正确,大于该比例的query会被丢弃,防止过于简单的query出现在结果中,默认为`None`,代表不过滤

- engine_kwargs:传入sampler_engine的额外参数,以json string传入,例如`{"cache_max_entry_count":0.7}`

- num_return_sequences:采样返回的原始sequence数量。默认为64,本参数对`sample`采样有效
- cache_files:为避免同时加载prm和generator造成显存OOM,可以分两步进行采样,第一步将prm和orm置为`None`,则所有结果都会输出到文件中,第二次运行采样将sampler_engine置为`no`并传入`--cache_files`为上次采样的输出文件,则会使用上次输出的结果进行prm和orm评估并输出最终结果。

> 注意:使用cache_files时,--dataset仍然需要传入,这是因为cache_files的id是由原始数据计算的md5,需要把两部分信息结合使用。

## 特定模型参数
特定模型参数可以通过`--model_kwargs`或者环境变量进行设置,例如: `--model_kwargs '{"fps_max_frames": 12}'`或者`FPS_MAX_FRAMES=12`
Expand Down
66 changes: 66 additions & 0 deletions docs/source/Instruction/采样.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# 采样

采样是SWIFT新支持的重要能力之一,这部分可以理解为`test-time compute`的落地实现。同时,该能力对RFT(强化微调)的实现也至关重要。

## 能力介绍

SWIFT的sample能力可以使用下面的例子进行:
```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```
在当前文件夹的`sample_output`目录下,会生成以时间戳为文件名的jsonl文件,该文件应该包含25行,每一行都是一个完整`messages`格式的数据。

采样的参数列表请参考[这里](命令行参数.md)

## 环境准备

```shell
pip install ms-swift[llm] -U
```

或从源代码安装:

```shell
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e '.[llm]'
```

## 使用PRM和ORM进行结果过滤

采样重要的能力就是对过程和结果进行监督,这可以通过设置额外参数来支持。

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
```

在当前文件夹的`sample_output`目录下,会生成以时间戳为文件名的jsonl文件,该文件**至多包含**10行,每一行都是一个完整`messages`格式的数据。
> 之所以至多包含10行,是因为虽然设置了共处理5个数据,每个数据保留2个(n_best_to_keep),但是orm可能会校验失败,失败数据不会保留到文件中。
> 另外,增加了--prm_model或--orm_model后文件格式有所不同,包含了rejected_response key,内容来自于prm评分最低的行。
## 自定义PRM或ORM

PRM和ORM的自定义可以在plugin中按照现有代码增加一个新的实现。例如:
```python
class CustomPRM:

# 构造需要是无参的
def __init__(self):
# init here
pass

@torch.inference_mode()
def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]:
...


prms = {'custom': CustomPRM}
```

之后在命令行中使用`--prm_model custom`即可。

## 实际例子

请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。

> 注意:该脚本的实际效果和模型、数据、RM的质量强相关,因此仅作为样例出现,用户请自行修改该脚本并训练自己的RM和generator模型。
26 changes: 26 additions & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,32 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
- hub_private_repo: Whether it is a private repo, default is False.
- commit_message: Commit message, default is 'update files'.

### Sampling Parameters

- prm_model: The type of process reward model. It can be a model ID (triggered using `pt`) or a `prm` key defined in a plugin (for custom inference processes).
- orm_model: The type of outcome reward model, typically a wildcard or test case, usually defined in a plugin.

- sampler_type: The type of sampling. Currently supports `sample` (using `do_sample` method). Future support will include `mcts` and `dvts`.
- sampler_engine: Supports `pt`, `lmdeploy`, `vllm`, `no`. Defaults to `pt`. Specifies the inference engine for the sampling model.
- output_dir: The output directory. Defaults to `sample_output`.
- output_file: The name of the output file. Defaults to `None`, which uses a timestamp as the filename. When provided, only the filename should be passed without the directory, and only JSONL format is supported.
- override_exist_file: Whether to overwrite if `output_file` already exists.
- num_sampling_per_gpu_batch_size: The batch size for each sampling operation.
- num_sampling_per_gpu_batches: The total number of batches to sample.
- n_best_to_keep: The number of best sequences to return.
- data_range: The partition of the dataset being processed for this sampling operation. The format should be `2 3`, meaning the dataset is divided into 3 parts, and this instance is processing the 3rd partition (this implies that typically three `swift sample` processes are running in parallel).

- temperature: Defaults to `1.0`.
- prm_threshold: The PRM threshold. Results below this value will be filtered out. The default value is `0`.
- easy_query_threshold: For each query, if the ORM evaluation is correct for more than this proportion of all samples, the query will be discarded to prevent overly simple queries from appearing in the results. Defaults to `None`, meaning no filtering is applied.

- engine_kwargs: Additional parameters for the `sampler_engine`, passed as a JSON string, for example, `{"cache_max_entry_count":0.7}`.

- num_return_sequences: The number of original sequences returned by sampling. Defaults to `64`. This parameter is effective for `sample` sampling.
- cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results.

> Note: When using `cache_files`, the `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together.
## Specific Model Arguments

Specific model arguments can be set using `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
Expand Down
69 changes: 69 additions & 0 deletions docs/source_en/Instruction/Sample.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Sampling

Sampling is one of the newly supported key capabilities of SWIFT. This feature can be understood as the practical implementation of `test-time compute`. Additionally, this capability is crucial for the implementation of RFT (Reinforcement Fine-Tuning).

## Capability Introduction

The sampling capability of SWIFT can be demonstrated with the following example:

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```

A `jsonl` file with a timestamp as the filename will be generated in the `sample_output` directory of the current folder. This file should contain 25 lines, each representing a complete `messages` format data.

For a list of sampling parameters, please refer to [here](Command-line-parameters.md).

## Environment Setup

```shell
pip install ms-swift[llm] -U
```

Or install swift from source:

```shell
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e '.[llm]'
```

## Using PRM and ORM for Result Filtering

An important capability of sampling is supervising the process and results, which can be supported by setting additional parameters.

```shell
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
```

A `jsonl` file with a timestamp as the filename will be generated in the `sample_output` directory of the current folder. This file **will contain at most** 10 lines, each representing a complete `messages` format data.
> The reason it contains at most 10 lines is that although 5 data points are processed in total, and 2 are kept for each data point (`n_best_to_keep`), ORM may fail some validations, and failed data will not be retained in the file.
> Additionally, after adding `--prm_model` or `--orm_model`, the file format is slightly different and includes a `rejected_response` key, which contains the responses with the lowest PRM scores.
## Customizing PRM or ORM

PRM and ORM can be customized by adding a new implementation in the plugin according to the existing code. For example:

```python
class CustomPRM:

# The constructor should be parameterless
def __init__(self):
# Initialize here
pass

@torch.inference_mode()
def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]:
...


prms = {'custom': CustomPRM}
```

Afterward, use `--prm_model custom` in the command line.

## Practical Example

Please refer to the [Reinforcement Fine-Tuning Script](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py). This script provides a practical example of using sampling for reinforcement fine-tuning.

> **Note:** The actual effectiveness of this script is strongly related to the quality of the model, data, and RM. Therefore, it is presented only as an example. Users should modify this script and train their own RM and generator models accordingly.
1 change: 1 addition & 0 deletions scripts/rft/math.json

Large diffs are not rendered by default.

Loading

0 comments on commit e9f4f9f

Please sign in to comment.