
Commit 7f8ed2f

minmingzhu committed

format

Signed-off-by: minmingzhu <minming.zhu@intel.com>
1 parent 6e4fe7f

File tree

3 files changed: +104, -98 lines changed

llm_on_ray/common/dataprocesser/general_processer.py

Lines changed: 12 additions & 8 deletions

@@ -116,15 +116,19 @@ def tokenize_function(self, examples, tokenizer):
                 new_message = PROMPT_NO_INPUT_FORMAT.format(
                     instruction=instruction, response=response
                 )
-                return tokenizer(new_message, add_special_tokens=False, max_length=self.config.get("max_length"))
+                return tokenizer(
+                    new_message, add_special_tokens=False, max_length=self.config.get("max_length")
+                )
             else:
                 new_messages = [
                     {
                         "role": "user",
                         "content": "###Instruction:\n"
-                        + examples["instruction"] + "\n\n"
-                        + "###context:\n"
-                        + examples["context"] + "\n\n",
+                        + examples["instruction"]
+                        + "\n\n"
+                        + "###context:\n"
+                        + examples["context"]
+                        + "\n\n",
                     },
                     {"role": "assistant", "content": examples["response"] + "\n\n"},
                 ]
@@ -145,9 +149,9 @@ def tokenize_function(self, examples, tokenizer):
                 new_messages,
                 tokenize=False,
             )
-            tokenizer = tokenizer(new_tokenizer,
-                                  add_special_tokens=False,
-                                  max_length=self.config.get("max_length"))
+            tokenizer = tokenizer(
+                new_tokenizer, add_special_tokens=False, max_length=self.config.get("max_length")
+            )
             return tokenizer
 
     def prepare(self, tokenizer, dataset):
@@ -184,7 +188,7 @@ def group_texts(examples):
             total_length = (total_length // block_size) * block_size
             # Split by chunks of max_len.
             result = {
-                k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                 for k, t in concatenated_examples.items()
             }
             result["labels"] = result["input_ids"].copy()

llm_on_ray/finetune/finetune.py

Lines changed: 1 addition & 1 deletion

@@ -358,7 +358,7 @@ def main(external_config=None):
 
     if "xpu" in ipex.__version__:
         num_cpus = (
-            resources_per_worker["CPU"] * num_training_workers + 1
+            resources_per_worker["CPU"] * num_training_workers + 1
        ) # additional 1 for head worker
        ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
    else:
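
The removed and added lines above render identically, which suggests an indentation-only reflow; the logic still reserves one extra CPU for Ray's head worker on XPU. A quick worked example of that arithmetic, with assumed values standing in for the real finetuning config:

# Assumed example values; in the project they come from the finetuning config.
resources_per_worker = {"CPU": 8}
num_training_workers = 4
num_cpus = (
    resources_per_worker["CPU"] * num_training_workers + 1
)  # additional 1 for head worker
print(num_cpus)  # 33
# The real code then calls ray.init(num_cpus=num_cpus, runtime_env=runtime_env).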

tests/finetune/test_chat_template.py

Lines changed: 91 additions & 89 deletions

@@ -7,133 +7,135 @@
 
 
 class TestTokenizeFunction(unittest.TestCase):
     def setUp(self):
-        self.tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
+        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         self.config = {
-            'gpt_base_model': True,
-            'max_length': 512,
-            'trust_remote_code': False,
-            'chat_template': "Below is an instruction that describes a task. Write a response that appropriately "
-                             "completes the request\n {% if messages[0]['role'] == 'system' %}{{ raise_exception("
-                             "'System role not supported') }}{% endif %}{% for message in messages %}{% if (message["
-                             "'role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles "
-                             "must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] "
-                             "== 'user' %}{{ '### Instruction: ' + message['content'] }}{% elif message['role'] == "
-                             "'assistant' %}{{ '### Response: ' + message['content'] }}{% endif %}{% endfor %}{{'### "
-                             "End \n'}}",
+            "gpt_base_model": True,
+            "max_length": 512,
+            "trust_remote_code": False,
+            "chat_template": "Below is an instruction that describes a task. Write a response that appropriately "
+            "completes the request\n {% if messages[0]['role'] == 'system' %}{{ raise_exception("
+            "'System role not supported') }}{% endif %}{% for message in messages %}{% if (message["
+            "'role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles "
+            "must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] "
+            "== 'user' %}{{ '### Instruction: ' + message['content'] }}{% elif message['role'] == "
+            "'assistant' %}{{ '### Response: ' + message['content'] }}{% endif %}{% endfor %}{{'### "
+            "End \n'}}",
         }
         self.processer = GeneralProcesser(self.config)
 
     def test_tokenize_function_with_gpt_model(self):
-        self.tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6b')
+        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
 
-        examples = \
-            {
-                "instruction": "Test instruction",
-                "response": "Test response",
-                "context": "Test context",
-            }
+        examples = {
+            "instruction": "Test instruction",
+            "response": "Test response",
+            "context": "Test context",
+        }
 
         # Verify the format of the result
-        expected_result = 'Below is an instruction that describes a task. Write a response that '\
-                          'appropriately completes the request.\n'\
-                          '\n'\
-                          '### Instruction:\n'\
-                          'Test instruction\n'\
-                          '\n'\
-                          'Input:\n'\
-                          'Test context\n'\
-                          '\n'\
-                          '### Response:\n'\
-                          'Test response\n'\
-                          '\n'\
-                          '### End'
+        expected_result = (
+            "Below is an instruction that describes a task. Write a response that "
+            "appropriately completes the request.\n"
+            "\n"
+            "### Instruction:\n"
+            "Test instruction\n"
+            "\n"
+            "Input:\n"
+            "Test context\n"
+            "\n"
+            "### Response:\n"
+            "Test response\n"
+            "\n"
+            "### End"
+        )
 
         result = self.processer.tokenize_function(examples, self.tokenizer)
-        self.assertEqual(self.tokenizer.decode(result['input_ids']), expected_result)
+        self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
 
     def test_tokenize_function_with_custom_chat_template(self):
-        examples = \
-            {
-                "instruction": "Test instruction",
-                "response": "Test response",
-                "context": "Test context",
-            }
+        examples = {
+            "instruction": "Test instruction",
+            "response": "Test response",
+            "context": "Test context",
+        }
 
         # Verify the format of the result
-        expected_result = '<|im_start|>user\n' \
-                          '###Instruction:\n' \
-                          'Test instruction\n' \
-                          '\n' \
-                          '###context:\n' \
-                          'Test context\n' \
-                          '\n' \
-                          '<|im_end|><|im_start|>assistant\n' \
-                          'Test response\n' \
-                          '\n' \
-                          '<|im_end|>'
+        expected_result = (
+            "<|im_start|>user\n"
+            "###Instruction:\n"
+            "Test instruction\n"
+            "\n"
+            "###context:\n"
+            "Test context\n"
+            "\n"
+            "<|im_end|><|im_start|>assistant\n"
+            "Test response\n"
+            "\n"
+            "<|im_end|>"
+        )
         # Set custom chat template
-        self.config['custom_chat_template'] = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"\
-                                              "+ message['content'] + '<|im_end|>'}}{% endfor %}"
+        self.config["custom_chat_template"] = (
+            "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
+            "+ message['content'] + '<|im_end|>'}}{% endfor %}"
+        )
 
-        self.config['gpt_base_model'] = False
+        self.config["gpt_base_model"] = False
         result = self.processer.tokenize_function(examples, self.tokenizer)
-        self.assertEqual(self.tokenizer.decode(result['input_ids']), expected_result)
+        self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
 
     def test_tokenize_function_with_chat_template(self):
-        examples = \
-            {
-                "instruction": "Test instruction",
-                "response": "Test response",
-                "context": "Test context",
-            }
+        examples = {
+            "instruction": "Test instruction",
+            "response": "Test response",
+            "context": "Test context",
+        }
 
         # Verify the format of the result
-        expected_result = 'Below is an instruction that describes a task. Write a response that '\
-                          'appropriately completes the request\n'\
-                          '### Instruction: ###Instruction:\n'\
-                          'Test instruction\n'\
-                          '\n'\
-                          '###context:\n'\
-                          'Test context\n'\
-                          '\n'\
-                          '### Response: Test response\n'\
-                          '\n'\
-                          '### End \n'\
-
-        self.config['gpt_base_model'] = False
+        expected_result = (
+            "Below is an instruction that describes a task. Write a response that "
+            "appropriately completes the request\n"
+            "### Instruction: ###Instruction:\n"
+            "Test instruction\n"
+            "\n"
+            "###context:\n"
+            "Test context\n"
+            "\n"
+            "### Response: Test response\n"
+            "\n"
+            "### End \n"
+        )
+        self.config["gpt_base_model"] = False
         result = self.processer.tokenize_function(examples, self.tokenizer)
-        self.assertEqual(self.tokenizer.decode(result['input_ids']), expected_result)
+        self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
 
     def test_tokenize_function_with_default_chat_template(self):
-        self.tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b-it')
-        examples = \
-            {
-                "instruction": "Test instruction",
-                "response": "Test response",
-                "context": "Test context",
-            }
+        self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+        examples = {
+            "instruction": "Test instruction",
+            "response": "Test response",
+            "context": "Test context",
+        }
 
         chat_example = [
             {
                 "role": "user",
                 "content": "###Instruction:\nTest instruction\n\n###context:\nTest context\n\n",
-
             },
             {
                 "role": "assistant",
                 "content": "Test response\n\n",
-            }
+            },
         ]
 
         # Verify the format of the result
-        expected_result = self.tokenizer.apply_chat_template(chat_example,
-                                                             tokenize=False,
-                                                             max_length=self.config.get("max_length"))
+        expected_result = self.tokenizer.apply_chat_template(
+            chat_example, tokenize=False, max_length=self.config.get("max_length")
+        )
 
-        self.config['gpt_base_model'] = False
+        self.config["gpt_base_model"] = False
         result = self.processer.tokenize_function(examples, self.tokenizer)
-        self.assertEqual(self.tokenizer.decode(result['input_ids']), expected_result)
+        self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
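
For context on what these tests exercise: `tokenizer.apply_chat_template` renders a list of role/content messages through a Jinja template, either the tokenizer's built-in one or a template string assigned to it. Below is a minimal sketch of the custom-template case, assuming the gpt-j tokenizer is downloadable; the template string is the same ChatML-style one the test stores under `custom_chat_template`:

from transformers import AutoTokenizer

# Assumes model files are available; any tokenizer that accepts a chat template works.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer.chat_template = (
    "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
    "+ message['content'] + '<|im_end|>'}}{% endfor %}"
)
messages = [
    {"role": "user", "content": "###Instruction:\nTest instruction\n\n"},
    {"role": "assistant", "content": "Test response\n\n"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)  # <|im_start|>user ... <|im_end|><|im_start|>assistant ... <|im_end|>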
