Commit b397104

chengeharrison, Xu Yuanchen, and Camille7777 authored
[Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)
* Add finetuning Colossal-Llama-2 example
* Add finetuning Colossal-Llama-2 example 2
* Add finetuning Colossal-Llama-2 example and support NEFTuning
* Add inference example and refine neftune
* Modify readme file
* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
1 parent 3dbbf83 commit b397104

9 files changed: +1036 -19 lines changed


applications/Colossal-LLaMA-2/README.md

Lines changed: 86 additions & 4 deletions
@@ -11,7 +11,10 @@
1111
- [Performance Evaluation](#performance-evaluation)
1212
- [Examples](#examples)
1313
- [Training Logs](#training-logs)
14-
- [Import from Transformers (Inference)](#import-from-transformers-inference)
14+
- [Inference](#inference)
15+
- [Import from HuggingFace](#import-from-huggingface)
16+
- [Import from Modelscope](#import-from-modelscope)
17+
- [Quick Start](#quick-start)
1518
- [Usage](#usage)
1619
- [Install](#install)
1720
- [0. Pre-requisite](#0-pre-requisite)
@@ -21,8 +24,14 @@
2124
- [1. Init Tokenizer Preparation](#1-init-tokenizer-preparation)
2225
- [2. Init Model Preparation](#2-init-model-preparation)
2326
- [3. Data Preparation](#3-data-preparation)
27+
- [3.1 Data for Pretraining](#31-data-for-pretraining)
28+
- [3.2 Data for Supervised Fine-tuning](#32-data-for-supervised-fine-tuning)
2429
- [4. Command Line Arguments for Training](#4-command-line-arguments-for-training)
30+
- [4.1 Arguments for Pretraining](#41-arguments-for-pretraining)
31+
- [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning)
2532
- [5. Running Command](#5-running-command)
33+
- [5.1 Command for Pretraining](#51-command-for-pretraining)
34+
- [5.2 Command for Supervised Fine-tuning](#52-command-for-supervised-fine-tuning)
2635
- [Technical Insights](#technical-insights)
2736
- [Data](#data)
2837
- [Tokenizer](#tokenizer)
@@ -117,7 +126,8 @@ We also recorded the training logs for the experiment
117126
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/trainingLossByTokens.jpeg?raw=true" width=600/>
118127
</p>
119128

120-
### Import from Transformers (Inference)
129+
### Inference
130+
#### Import from HuggingFace
121131
To load the Colossal-LLaMA-2-7B-base model using Transformers, use the following code:
122132
```Python
123133
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -135,14 +145,15 @@ pred = model.generate(**inputs,
135145
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
136146
```
137147

148+
#### Import from Modelscope
138149
You can also load our model using ModelScope with the following code:
139150
```Python
140151
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
141152
model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1')
142153
tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
143154
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
144-
generation_kwargs = {"max_new_tokens": 256,
145-
"top_p": 0.95,
155+
generation_kwargs = {"max_new_tokens": 256,
156+
"top_p": 0.95,
146157
"temperature": 0.3
147158
}
148159
input = '离离原上草,'
@@ -153,6 +164,30 @@ print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input):])
153164
```
154165
You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary).
155166

167+
#### Quick Start
168+
You can run [`inference_example.py`](inference_example.py) to quickly try inference with our base model, loading the model weights from HuggingFace.
169+
170+
Command to run the script:
171+
```bash
172+
python inference_example.py \
173+
--model_path "<HF_REPO_NAME_OR_LOCAL_PATH_TO_MODEL>" \
174+
--device "cuda:0" \
175+
--max_new_tokens 512 \
176+
--do_sample True \
177+
--temperature 0.3 \
178+
--top_k 50 \
179+
--top_p 0.95 \
180+
--input_txt "YOUR_PROMPT_OR_QUESTION"
181+
```
182+
Here are the details of the CLI arguments:
183+
* Model path: `--model_path`. HF repo name or local path of the model.
184+
* Device: `--device`. Set the device.
185+
* Max new tokens: `--max_new_tokens`. Set the maximum number of tokens to generate, not counting the tokens in the prompt.
186+
* Do sample: `--do_sample`. Set whether or not to use sampling.
187+
* Temperature: `--temperature`. Set temperature value.
188+
* Top_k: `--top_k`. Set top_k value for top-k-filtering.
189+
* Top_p: `--top_p`. Set top_p value for generation.
190+
* Input_txt: `--input_txt`. The prompt string input to the model.
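
As a rough illustration of how the flags above could map onto generation options, here is a minimal sketch. It is not the contents of `inference_example.py`: the helper names `str2bool`, `build_parser`, and `generation_kwargs` are hypothetical, and the defaults are simply copied from the example command. The only assumption about Transformers is that `generate` accepts `max_new_tokens`, `do_sample`, `temperature`, `top_k`, and `top_p`.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse "True"/"False" strings; a plain `type=bool` would turn "False" into True."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Flag names follow the README; defaults are assumptions taken from the example command.
    parser = argparse.ArgumentParser(description="Hypothetical mirror of the inference CLI")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--do_sample", type=str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--input_txt", type=str, required=True)
    return parser


def generation_kwargs(args: argparse.Namespace) -> dict:
    """Collect the sampling options that would be forwarded to model.generate(**inputs, **kwargs)."""
    return {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
    }


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(
    ["--model_path", "hpcai-tech/Colossal-LLaMA-2-7b-base", "--do_sample", "False", "--input_txt", "离离原上草,"]
)
kwargs = generation_kwargs(args)
```

Parsing `--do_sample` through an explicit string-to-bool helper matters because the command passes the literal string `True`, and any non-empty string is truthy in Python.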
156191
## Usage
157192
### Install
158193

@@ -218,6 +253,8 @@ Here is details about CLI arguments:
218253
❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
219254

220255
#### 3. Data Preparation
256+
257+
##### 3.1 Data for Pretraining
221258
Raw data should be formatted as `jsonl` format. Each data point should have the following fields:
222259
* `source` (str, compulsory): This part is ignored when calculating loss. Default can be empty.
223260
* `target` (str, compulsory): Loss will be calculated.
@@ -250,7 +287,31 @@ Here is details about CLI arguments:
250287
* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
251288
* Number of bins for each category: `num_spliced_dataset_bins`. Number of bins for each category, used for bucket-based training.
252289

290+
##### 3.2 Data for Supervised Fine-tuning
291+
We prepare data for supervised fine-tuning in a similar way. The main difference lies in the data format. Each data point should have the following field:
292+
* `messages` (list, compulsory): This part consists of a conversation between a human and assistant. The length of `messages` can vary and only content from `assistant` is used for calculating loss.
293+
294+
Examples:
295+
```JSON
296+
{"messages": [{"from": "human", "content": "What are the three primary colors?"}, {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."}]}
297+
{"messages": [{"from": "human", "content": "解释个人电脑和服务器之间的区别。"}, {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]}
298+
```
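
To make the record format concrete, the following sketch reads one `jsonl` line and marks which turns would contribute to the loss under the rule stated above (only `assistant` content). The function name `split_loss_turns` and the validation rules are assumptions for illustration; this is not the logic of `prepare_sft_dataset.py`.

```python
import json
from typing import List, Tuple


def split_loss_turns(jsonl_line: str) -> List[Tuple[str, str, bool]]:
    """Return (speaker, content, contributes_to_loss) for each turn of one record."""
    record = json.loads(jsonl_line)
    messages = record["messages"]
    if not isinstance(messages, list) or not messages:
        raise ValueError("`messages` must be a non-empty list")
    turns = []
    for message in messages:
        speaker, content = message["from"], message["content"]
        if speaker not in ("human", "assistant"):
            raise ValueError(f"unexpected speaker: {speaker!r}")
        # Only assistant turns are flagged as loss-bearing.
        turns.append((speaker, content, speaker == "assistant"))
    return turns


line = json.dumps({"messages": [
    {"from": "human", "content": "What are the three primary colors?"},
    {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."},
]})
turns = split_loss_turns(line)
```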
299+
300+
The command to convert a jsonl dataset to arrow format is similar to the one in [3.1 Data for Pretraining](#31-data-for-pretraining). Note that `prepare_sft_dataset.py` does not concatenate different data samples.
301+
```
302+
python prepare_sft_dataset.py \
303+
--data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
304+
--tokenizer_dir "<TOKENIZER_DIR>" \
305+
--data_cache_dir "jsonl_to_arrow_cache" \
306+
--data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
307+
--data_arrow_output_dir "spliced_tokenized_output_arrow" \
308+
--max_length 4096 \
309+
--num_spliced_dataset_bins 10
310+
```
311+
253312
#### 4. Command Line Arguments for Training
313+
314+
##### 4.1 Arguments for Pretraining
254315
You can use `colossalai run` to launch multi-node training:
255316
```bash
256317
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
@@ -288,7 +349,16 @@ Here is details about CLI arguments:
288349
* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
289350
* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
290351

352+
##### 4.2 Arguments for Supervised Fine-tuning
353+
Supervised fine-tuning additionally supports gradient accumulation and NEFTuning, so it takes two arguments on top of those listed in [4.1 Arguments for Pretraining](#41-arguments-for-pretraining).
354+
355+
Here are the details of the additional CLI arguments:
356+
* Accumulation steps: `--accumulation_steps`. The default value is `8`.
357+
* NEFTuning: `--use_neft`. Whether to enable NEFTune (noisy embeddings during fine-tuning), which can help improve the performance of chat models. The default value is `False`.
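
For readers unfamiliar with NEFTune, the sketch below shows the noise rule described in the cited paper: uniform noise in [-1, 1) is added to the token embeddings of a sequence, scaled by `alpha / sqrt(L * d)`, where `L` is the sequence length and `d` the embedding dimension. It uses plain Python lists instead of tensors, and the function names and the value `alpha=5.0` are illustrative assumptions; it is not the implementation added in this commit.

```python
import math
import random
from typing import List, Optional


def neftune_noise_scale(seq_len: int, hidden_dim: int, alpha: float) -> float:
    """Scale factor alpha / sqrt(L * d) from the NEFTune paper."""
    return alpha / math.sqrt(seq_len * hidden_dim)


def add_neftune_noise(
    embeddings: List[List[float]],
    alpha: float = 5.0,  # illustrative value, not a default taken from this repository
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Return a noisy copy of a (seq_len x hidden_dim) embedding matrix."""
    rng = rng or random.Random()
    seq_len, hidden_dim = len(embeddings), len(embeddings[0])
    scale = neftune_noise_scale(seq_len, hidden_dim, alpha)
    # Each entry receives independent uniform noise; the input is left unmodified.
    return [[x + scale * rng.uniform(-1.0, 1.0) for x in row] for row in embeddings]


# Toy example: 4 tokens with 8-dimensional embeddings.
toy = [[0.0] * 8 for _ in range(4)]
noisy = add_neftune_noise(toy, alpha=5.0, rng=random.Random(0))
```

The noise is applied only during training; at inference time the embeddings are used unchanged.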
358+
291359
#### 5. Running Command
360+
361+
##### 5.1 Command for Pretraining
292362
An [example bash](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:
293363
* Create your own hostfile: `cp hostfile.example hostfile`.
294364
* Create your own bash: `cp train.example.sh train.sh`.
@@ -310,6 +380,10 @@ declare -a dataset=(
310380
"<DIR_2>/part-00000"
311381
)
312382
```
383+
384+
##### 5.2 Command for Supervised Fine-tuning
385+
An [example bash](train_sft.example.sh) is provided. The only difference from the pretraining command is the two additional arguments (`--accumulation_steps` and `--use_neft`) in the script. You can refer to [4.2 Arguments for Supervised Fine-tuning](#42-arguments-for-supervised-fine-tuning) for more details.
386+
313387
## Technical Insights
314388
To enhance LLaMA-2's capabilities for understanding and generating Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continuing the pre-training of the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows:
315389

@@ -416,3 +490,11 @@ Applying the above process to perform knowledge transfer in any field allows for
416490
year={2023}
417491
}
418492
```
493+
```bibtex
494+
@article{jain2023neftune,
495+
title={NEFTune: Noisy Embeddings Improve Instruction Finetuning},
496+
author={Jain, Neel and Chiang, Ping-yeh and Wen, Yuxin and Kirchenbauer, John and Chu, Hong-Min and Somepalli, Gowthami and Bartoldson, Brian R and Kailkhura, Bhavya and Schwarzschild, Avi and Saha, Aniruddha and others},
497+
journal={arXiv preprint arXiv:2310.05914},
498+
year={2023}
499+
}
500+
```
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    ADD_BOS_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle
    seps: List[str]

    def clear(self):
        self.messages = []

    def get_prompt(self, length: int = None):
        if length is None:
            length = len(self.messages)

        if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages[0:length]:
                if message:
                    ret += role + ": " + self.seps[0] + message + self.seps[1]
                else:
                    ret += role + ": " + self.seps[0]
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def save_prompt(self):
        if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + self.seps[0] + message + self.seps[1] + "\n"
                else:
                    ret += role + ": " + self.seps[0]
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            seps=self.seps,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "seps": self.seps,
        }


conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
    seps=["<s>", "</s>"],
)

default_conversation = conv
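
The prompt layout produced by `get_prompt` can be illustrated with a small standalone sketch. The function `build_prompt` below is a hypothetical re-statement of the `ADD_BOS_EOS_TOKEN` branch, written so it runs without the file above; the short system string is a placeholder, not the template's real system prompt.

```python
from typing import List, Optional, Sequence, Tuple


def build_prompt(
    system: str,
    turns: List[Tuple[str, Optional[str]]],
    seps: Sequence[str] = ("<s>", "</s>"),
) -> str:
    """Concatenate turns as `role: <s>message</s>`; an empty message leaves the turn open."""
    prompt = system
    for role, message in turns:
        if message:
            prompt += role + ": " + seps[0] + message + seps[1]
        else:
            # Open turn: the model is expected to continue from here.
            prompt += role + ": " + seps[0]
    return prompt


prompt = build_prompt("SYSTEM\n\n", [("Human", "Hi"), ("Assistant", None)])
print(repr(prompt))  # 'SYSTEM\n\nHuman: <s>Hi</s>Assistant: <s>'
```

Appending the assistant role with an empty message is what leaves the prompt ending in `Assistant: <s>`, so generation starts inside the assistant's turn.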
