Commit fb03941

feat: Add multi-turn SFT support (#195)

1 parent 4f245a3 commit fb03941
File tree

7 files changed: +553 -0 lines changed
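
All of the new files share one record layout: each row of the parquet data carries a `messages` column holding an ordered list of role/content dicts, for example (taken from the data-preparation script below):

    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is Paris."},
        ]
    }

The preparation script writes records in this shape, the launch scripts point `data.multiturn.messages_key` at that column, and the unit test checks that only the assistant turns contribute to the loss.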
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a simple multi-turn dataset for testing
"""

import os
import pandas as pd
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='~/data/multiturn')
    parser.add_argument('--hdfs_dir', default=None)
    args = parser.parse_args()

    # Create example conversations
    conversations = []

    # Conversation 1
    conversations.append({
        "messages": [{
            "role": "system",
            "content": "You are a helpful assistant."
        }, {
            "role": "user",
            "content": "What is the capital of France?"
        }, {
            "role": "assistant",
            "content": "The capital of France is Paris."
        }, {
            "role": "user",
            "content": "And what about Germany?"
        }, {
            "role": "assistant",
            "content": "The capital of Germany is Berlin."
        }]
    })

    # Conversation 2
    conversations.append({
        "messages": [{
            "role": "system",
            "content": "You are a helpful assistant."
        }, {
            "role": "user",
            "content": "Can you explain quantum computing?"
        }, {
            "role": "assistant",
            "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, such as superposition and entanglement, to perform operations on data."
        }, {
            "role": "user",
            "content": "How is it different from classical computing?"
        }, {
            "role": "assistant",
            "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses quantum bits or qubits that can exist in multiple states simultaneously due to superposition."
        }]
    })

    # Conversation 3
    conversations.append({
        "messages": [{
            "role": "system",
            "content": "You are a helpful assistant."
        }, {
            "role": "user",
            "content": "Write a simple Python function to calculate factorial."
        }, {
            "role": "assistant",
            "content": "```python\ndef factorial(n):\n    if n == 0 or n == 1:\n        return 1\n    else:\n        return n * factorial(n-1)\n```\n\nThis is a recursive function to calculate the factorial of a number."
        }, {
            "role": "user",
            "content": "Can you make it iterative instead?"
        }, {
            "role": "assistant",
            "content": "```python\ndef factorial(n):\n    result = 1\n    for i in range(1, n+1):\n        result *= i\n    return result\n```\n\nThis is an iterative version of the factorial function."
        }]
    })

    # Create train and test datasets
    train_data = conversations[:2]  # First 2 conversations for training
    test_data = conversations[2:]  # Last conversation for testing

    # Create output directory
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    # Save to parquet files
    train_df = pd.DataFrame(train_data)
    test_df = pd.DataFrame(test_data)

    train_df.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_df.to_parquet(os.path.join(local_dir, 'test.parquet'))

    # Handle HDFS if specified
    if args.hdfs_dir is not None:
        try:
            from verl.utils.hdfs_io import copy, makedirs
            makedirs(args.hdfs_dir)
            copy(src=local_dir, dst=args.hdfs_dir)
        except ImportError:
            print("Warning: HDFS support not available. Skipping HDFS copy.")

    # Print statistics
    print(f"Train dataset size: {len(train_df)}")
    print(f"Test dataset size: {len(test_df)}")
    print(f"Data saved to {local_dir}")


if __name__ == '__main__':
    main()
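
For a quick sanity check, the parquet files written by this script can be inspected with pandas. A minimal sketch, assuming the default `--local_dir` of `~/data/multiturn`:

    import os
    import pandas as pd

    # Load the training split written by the preparation script above.
    train_df = pd.read_parquet(os.path.expanduser('~/data/multiturn/train.parquet'))
    print(len(train_df))            # number of conversations in the split
    print(train_df['messages'][0])  # role/content dicts of the first conversation

Each row contains the full conversation, including the system prompt; the loss mask exercised by the unit test further down is what restricts training to the assistant turns.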
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
#!/bin/bash
set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/multiturn/train.parquet \
    data.val_files=$HOME/data/multiturn/test.parquet \
    data.multiturn.enable=true \
    data.multiturn.messages_key=messages \
    data.micro_batch_size=4 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=$save_path \
    trainer.project_name=multiturn-sft \
    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
    trainer.logger=['console'] \
    trainer.total_training_steps=1 \
    trainer.default_hdfs_dir=null $@ \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=true

tests/sft/run_sft_multiturn.sh

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
#!/bin/bash
set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/multiturn/train.parquet \
    data.val_files=$HOME/data/multiturn/test.parquet \
    data.multiturn.enable=true \
    data.multiturn.messages_key=messages \
    data.micro_batch_size=4 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=$save_path \
    trainer.project_name=multiturn-sft \
    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
    trainer.logger=['console'] \
    trainer.total_training_steps=1 \
    trainer.default_hdfs_dir=null $@ \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=true
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test the MultiTurnSFTDataset implementation
"""
import os
import pandas as pd
import torch
from transformers import AutoTokenizer
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset


def test_multiturn_sft_dataset():
    print("Starting test...")
    # Create a temporary parquet file with test data
    test_data = {
        'messages': [
            [{
                "role": "system",
                "content": "You are a helpful assistant."
            }, {
                "role": "user",
                "content": "What is 2+2?"
            }, {
                "role": "assistant",
                "content": "2+2 equals 4."
            }, {
                "role": "user",
                "content": "And what is 4+4?"
            }, {
                "role": "assistant",
                "content": "4+4 equals 8."
            }],
            [{
                "role": "system",
                "content": "You are a helpful assistant."
            }, {
                "role": "user",
                "content": "Tell me a joke."
            }, {
                "role": "assistant",
                "content": "Why did the chicken cross the road?"
            }, {
                "role": "user",
                "content": "Why?"
            }, {
                "role": "assistant",
                "content": "To get to the other side!"
            }]
        ]
    }

    # Create test directory if it doesn't exist
    os.makedirs('test_data', exist_ok=True)
    test_file = 'test_data/test.parquet'

    # Save test data to parquet
    df = pd.DataFrame(test_data)
    df.to_parquet(test_file)

    # Initialize tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct')
    config = {'max_length': 512, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
    dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)

    # Test 1: Dataset Length
    assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}"

    # Get items for testing
    item0 = dataset[0]  # Math conversation
    item1 = dataset[1]  # Joke conversation

    # Test 2: Required Keys and Types
    required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask']
    for key in required_keys:
        assert key in item0, f"Missing key {key} in dataset item"
        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}"
        assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}"

    # Test 3: Shape Consistency
    assert item0['loss_mask'].shape == item0['input_ids'].shape, \
        "Loss mask shape doesn't match input_ids shape"
    assert item0['attention_mask'].shape == item0['input_ids'].shape, \
        "Attention mask shape doesn't match input_ids shape"
    assert item0['position_ids'].shape == item0['input_ids'].shape, \
        "Position IDs shape doesn't match input_ids shape"

    # Test 4: Loss Mask Pattern - Math Conversation
    loss_mask0 = item0['loss_mask']
    input_ids0 = item0['input_ids']

    # Find assistant response positions
    assistant_positions0 = torch.where(loss_mask0 == 1)[0]
    assert len(assistant_positions0) > 0, "No assistant positions found in loss mask"

    # Decode and verify assistant responses
    assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1])
    print(f"Math conversation assistant text: {assistant_text0}")
    assert "2+2 equals 4" in assistant_text0, "First assistant response not found"
    assert "4+4 equals 8" in assistant_text0, "Second assistant response not found"

    # Test 5: Loss Mask Pattern - Joke Conversation
    loss_mask1 = item1['loss_mask']
    input_ids1 = item1['input_ids']

    # Find assistant response positions
    assistant_positions1 = torch.where(loss_mask1 == 1)[0]
    assert len(assistant_positions1) > 0, "No assistant positions found in loss mask"

    # Decode and verify assistant responses
    assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1])
    print(f"Joke conversation assistant text: {assistant_text1}")
    assert "chicken cross the road" in assistant_text1, "First assistant response not found"
    assert "other side" in assistant_text1, "Second assistant response not found"

    # Test 6: Attention Mask Pattern
    attention_mask0 = item0['attention_mask']
    sequence_length = torch.sum(attention_mask0)
    assert sequence_length > 0, "No tokens marked as attended in attention mask"
    assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern"
    if sequence_length < len(attention_mask0):
        assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked"

    # Test 7: Position IDs Pattern
    position_ids0 = item0['position_ids']
    assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \
        "Position IDs not sequential for non-padded tokens"
    if sequence_length < len(position_ids0):
        assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero"

    # Test 8: Verify loss mask for assistant responses
    # Get the full conversation text
    full_text = tokenizer.decode(input_ids0)
    print(f"\nFull conversation text:\n{full_text}")

    # Get the assistant responses
    assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1])
    print(f"\nAssistant responses (from loss mask):\n{assistant_text}")

    # Verify that loss mask is set for all assistant responses
    for msg in test_data['messages'][0]:  # First conversation
        if msg['role'] == 'assistant':
            # The content should appear in the masked text
            assert msg['content'] in assistant_text, \
                f"Assistant message '{msg['content']}' not found in masked text"

            # The content should NOT appear in the non-masked text
            non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
            assert msg['content'] not in non_assistant_text, \
                f"Assistant message '{msg['content']}' found in non-assistant text"

    # Test 9: Verify non-assistant parts have loss_mask=0
    # Get non-assistant text
    non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0])
    print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}")

    # Verify that system and user messages are in the non-assistant text
    for msg in test_data['messages'][0]:  # First conversation
        if msg['role'] in ['system', 'user']:
            assert msg['content'] in non_assistant_text, \
                f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text"

            # And verify they're NOT in the assistant text
            assert msg['content'] not in assistant_text, \
                f"{msg['role'].title()} message '{msg['content']}' found in assistant text"

    # Test 10: Verify padding behavior
    padding_config = {'max_length': 1024, 'truncation': 'error', 'multiturn': {'messages_key': 'messages'}}
    small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config)
    padded_item = small_dataset[0]

    # Get actual sequence length (before padding)
    actual_length = torch.sum(padded_item['attention_mask'])

    # Verify padding tokens
    assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \
        "Padding tokens not set correctly"
    assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \
        "Attention mask not set correctly for padding"
    assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \
        "Loss mask not set correctly for padding"

    print("All tests passed!")
    print("Starting test...")
