removing padding
openllmai0 committed Jul 8, 2024
1 parent 2eb5f88 commit bfeb067
Showing 5 changed files with 35 additions and 10 deletions.
README.md (9 changes: 6 additions & 3 deletions)
@@ -81,6 +81,7 @@ To use OpenRLHF, first ``git clone`` it and launch the docker container (**Recommended**)

```bash
git clone https://github.com/openllmai/OpenRLHF.git
+ # You can also `git checkout` the latest stable release version.

# If you need to use vLLM, please build a Docker image to avoid dependency issues (Optional)
docker build -t nvcr.io/nvidia/pytorch:24.02-py3 ./OpenRLHF/dockerfile
@@ -144,7 +145,7 @@ Then you can use the startup scripts we provide in the [examples/scripts](./examples/scripts)

```bash
deepspeed ./train_sft.py \
- --max_len 2048 \
+ --max_len 4096 \
--dataset Open-Orca/OpenOrca \
--input_key question \
--output_key response \
@@ -170,13 +171,15 @@ deepspeed ./train_sft.py \
# --input_key {JSON Key}
# --tokenizer_chat_template {HF Chat Template}

+ # SFT samples packing
+ # --packing_samples

# Can also be used for continued pre-training
# --pretrain_mode

```

> [!NOTE]
- > OpenRLHF SFT supports `--packing_samples` [using `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)
+ > OpenRLHF SFT supports `--packing_samples` [based on `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)
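
The linked functionary repository describes the packing trick in detail. As a rough, stand-alone illustration (with made-up token ids, not the OpenRLHF implementation): packing concatenates several tokenized samples into one padding-free row, restarts position ids per sample, and records cumulative sequence lengths so a variable-length FlashAttention kernel never attends across sample boundaries.

```python
# Minimal sketch of sample packing, using made-up token ids.
# Not the OpenRLHF code: it only shows the tensors a varlen
# FlashAttention kernel needs to keep attention inside each sample.
import torch

samples = [
    [101, 7592, 2088, 102],        # hypothetical tokenized sample 1
    [101, 2129, 2024, 2017, 102],  # sample 2
    [101, 102],                    # sample 3
]

input_ids = torch.tensor([tok for s in samples for tok in s])      # one packed row, no padding
position_ids = torch.cat([torch.arange(len(s)) for s in samples])  # positions restart per sample
seq_lens = torch.tensor([len(s) for s in samples])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])

print(input_ids.shape)  # torch.Size([11])
print(position_ids)     # tensor([0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1])
print(cu_seqlens)       # tensor([ 0,  4,  9, 11]) -> sample boundaries for the varlen kernel
```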

### Reward Model Training
README_zh.md (8 changes: 6 additions & 2 deletions)
@@ -84,6 +84,7 @@ OpenRLHF is a high-performance RLHF framework built on Ray, DeepSpeed, and HF Transformers

```bash
git clone https://github.com/openllmai/OpenRLHF.git
+ # You can also `git checkout` the latest stable release version.

# If you need to use vLLM, please build a Docker image to avoid dependency issues (optional)
docker build -t nvcr.io/nvidia/pytorch:24.02-py3 ./OpenRLHF/dockerfile
@@ -148,7 +149,7 @@ OpenRLHF model checkpoints are fully compatible with HuggingFace models. You can use `--

```bash
deepspeed ./train_sft.py \
- --max_len 2048 \
+ --max_len 4096 \
--dataset Open-Orca/OpenOrca \
--input_key question \
--output_key response \
@@ -174,12 +175,15 @@ deepspeed ./train_sft.py \
# --input_key {JSON Key}
# --tokenizer_chat_template {HF Chat Template}

+ # SFT sample packing is supported
+ # --packing_samples

# Can also be used for continued pre-training
# --pretrain_mode
```

> [!NOTE]
- > OpenRLHF SFT supports `--packing_samples` [based on `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)
+ > The `--packing_samples` option supported by OpenRLHF SFT is [based on `--flash_attn`](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)
### Reward Model Training
```bash
examples/scripts/train_iterative_dpo_llama.sh (4 changes: 2 additions & 2 deletions)
@@ -40,9 +40,9 @@ while (($iter < $TRAINING_ITERS)); do
--apply_chat_template \
--temperature 1.0 \
--tp_size 4 \
- --best_of_n 16 \
+ --best_of_n 8 \
--enable_prefix_caching \
- --max_num_seqs 128 \
+ --max_num_seqs 64 \
--iter $iter \
--rollout_batch_size $ROLLOUT_BATCH_SIZE \
--output_path $GENERATE_OUTPUT
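
For orientation, `--best_of_n` sets how many candidate responses are sampled per prompt during the generation step (lowered here from 16 to 8) before the reward-model step keeps the highest-scoring one, while `--max_num_seqs` limits how many sequences the vLLM engine schedules concurrently. Below is a rough sketch of the best-of-n selection idea, using stand-in `generate` and `reward_fn` callables rather than the actual OpenRLHF batch-inference code.

```python
# Rough illustration of what --best_of_n controls during generation
# (stand-in generate/reward callables; not the OpenRLHF batch_inference code).
from typing import Callable, List

def best_of_n(prompt: str,
              generate: Callable[[str], str],
              reward_fn: Callable[[str, str], float],
              n: int = 8) -> str:
    """Sample n candidate responses and keep the one the reward model scores highest."""
    candidates: List[str] = [generate(prompt) for _ in range(n)]
    scores = [reward_fn(prompt, c) for c in candidates]
    return candidates[scores.index(max(scores))]

# Toy usage with placeholder functions:
best = best_of_n(
    "What is RLHF?",
    generate=lambda p: p + " ... a sampled answer",
    reward_fn=lambda p, c: float(len(c)),  # placeholder reward: longer is "better"
    n=8,
)
print(best)
```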
examples/scripts/train_rejection_sampling_llama.sh (3 changes: 1 addition & 2 deletions)
Expand Up @@ -35,7 +35,6 @@ while (($iter < $TRAINING_ITERS)); do
--bf16 \
--max_new_tokens 2048 \
--prompt_max_len 2048 \
- --max_samples 128 \
--dataset OpenLLMAI/prompt-collection-v0.1 \
--input_key context_messages \
--apply_chat_template \
@@ -44,7 +43,7 @@ while (($iter < $TRAINING_ITERS)); do
--best_of_n 4 \
--enable_prefix_caching \
--tp_size 4 \
- --micro_batch_size 64 \
+ --max_num_seqs 64 \
--iter $iter \
--rollout_batch_size $ROLLOUT_BATCH_SIZE \
--output_path $GENERATE_OUTPUT
openrlhf/models/actor.py (21 changes: 20 additions & 1 deletion)
@@ -9,6 +9,7 @@
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
from transformers.deepspeed import HfDeepSpeedConfig

+from openrlhf.datasets.utils import zero_pad_sequences
from .utils import log_probs_from_logits


@@ -161,7 +162,25 @@ def process_sequences(self, sequences: torch.Tensor, input_len, eos_token_id, pad_token_id):
         action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id)
         action_mask[:, 0] = 1
 
-        return sequences, attention_mask, action_mask
+        # removing padding
+        sequences2 = []
+        attention_mask2 = []
+        action_mask2 = []
+
+        for seq, att_mask, act_mask in zip(sequences, attention_mask, action_mask):
+            right_pad = (1 - act_mask.long()).sum()
+            right_pad = None if right_pad == 0 else -right_pad
+            left_pad = att_mask.long().argmax()
+
+            sequences2.append(seq[left_pad:right_pad])
+            attention_mask2.append(att_mask[left_pad:right_pad])
+            action_mask2.append(act_mask[:right_pad])
+
+        return (
+            zero_pad_sequences(sequences2, "left"),
+            zero_pad_sequences(attention_mask2, "left"),
+            zero_pad_sequences(action_mask2, "left"),
+        )
 
     def forward(
         self,
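
To see what the new `process_sequences` logic does, here is a simplified, self-contained sketch of the same idea on a toy batch. A local `zero_pad_left` helper stands in for `zero_pad_sequences`, and the pad counts are derived directly from token values rather than from the attention/action masks as in the real method.

```python
# Simplified, self-contained sketch of the idea in the diff above (not the
# Actor class itself). Each row is sliced between its first and last real
# token, then the batch is re-padded on the left to the new longest length,
# so no row carries more padding than the longest trimmed sequence needs.
import torch
import torch.nn.functional as F

def zero_pad_left(rows, value=0):
    """Toy stand-in for openrlhf.datasets.utils.zero_pad_sequences(rows, "left")."""
    max_len = max(r.numel() for r in rows)
    return torch.stack([F.pad(r, (max_len - r.numel(), 0), value=value) for r in rows])

pad = 0
batch = [
    torch.tensor([pad, pad, 5, 6, 7, 8, pad, pad, pad]),  # toy generated row 1
    torch.tensor([pad, 3, 4, 5, 6, 7, 8, 9, pad]),        # toy generated row 2
]

trimmed = []
for seq in batch:
    mask = seq.ne(pad).long()
    left_pad = mask.argmax()                      # index of the first real token
    right_pad = (1 - mask).sum() - left_pad       # number of trailing pad tokens
    end = None if right_pad == 0 else -right_pad  # slice end (None keeps the tail)
    trimmed.append(seq[left_pad:end])

packed = zero_pad_left(trimmed)
print(packed)
# tensor([[0, 0, 0, 5, 6, 7, 8],
#         [3, 4, 5, 6, 7, 8, 9]])
```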
