From 558b983d30a4f157846321fd17d5e9d6dae7db6e Mon Sep 17 00:00:00 2001 From: ylfeng Date: Wed, 18 Sep 2024 06:11:10 +0800 Subject: [PATCH] 1. support flatting_packing 2. update mistral format function call 3. fix knapsack, may cause #5443 4. avoid supervised examples wrongly truncation #5426 --- src/llamafactory/cli.py | 1 - src/llamafactory/data/__init__.py | 3 +- src/llamafactory/data/collator.py | 39 +++++++- src/llamafactory/data/formatter.py | 46 ++++++++- src/llamafactory/data/preprocess.py | 8 +- .../data/processors/processor_utils.py | 6 ++ .../data/processors/supervised.py | 98 +++++++++++-------- src/llamafactory/data/template.py | 61 ++---------- src/llamafactory/data/tool_utils.py | 31 +++++- src/llamafactory/hparams/data_args.py | 7 ++ src/llamafactory/train/sft/workflow.py | 39 +++++--- 11 files changed, 224 insertions(+), 115 deletions(-) diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 8012d85549..d1053f2f5d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -28,7 +28,6 @@ from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui - USAGE = ( "-" * 70 + "\n" diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index ea1a02f20c..2161422793 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -17,17 +17,18 @@ MultiModalDataCollatorForSeq2Seq, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask, + SFTDataCollatorWithFlattingPacking, ) from .data_utils import Role, split_dataset from .loader import get_dataset from .template import TEMPLATES, Template, get_template_and_fix_tokenizer - __all__ = [ "KTODataCollatorWithPadding", "MultiModalDataCollatorForSeq2Seq", "PairwiseDataCollatorWithPadding", "SFTDataCollatorWith4DAttentionMask", + "SFTDataCollatorWithFlattingPacking", "Role", "split_dataset", "get_dataset", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 92d86cc754..75fb937b88 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -19,8 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence import torch -from transformers import DataCollatorForSeq2Seq - +from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, default_data_collator, PreTrainedTokenizerBase if TYPE_CHECKING: from transformers import ProcessorMixin @@ -120,6 +119,42 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso return features +@dataclass +class SFTDataCollatorWithFlattingPacking(DefaultDataCollator): + r""" + Data collator for flatting packing. 
+ """ + + tokenizer: PreTrainedTokenizerBase = None + label_pad_token_id: int = -100 + template: Optional["Template"] = None + processor: Optional["ProcessorMixin"] = None + return_position_ids: bool = True + + def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, "torch.Tensor"]: + # todo: not support multi-model + if return_tensors is None: + return_tensors = self.return_tensors + is_labels_provided = "labels" in features[0] + ret = {"input_ids": [], "labels": []} + if self.return_position_ids: + ret.update({"position_ids": []}) + for instances in features: + for input_ids, labels in zip(instances["input_ids"], instances["labels"]): + ret["input_ids"] += input_ids + if is_labels_provided: + ret["labels"] += [self.label_pad_token_id] + labels[1:] + else: + ret["labels"] += [self.label_pad_token_id] + input_ids[1:] + if self.return_position_ids: + ret["position_ids"] += list(range(len(input_ids))) + + assert len(ret["input_ids"]) == len(ret["labels"]) + + features: Dict[str, "torch.Tensor"] = default_data_collator([ret], return_tensors) + return features + + @dataclass class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r""" diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index f8b3979a7c..0efd267148 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -23,7 +23,6 @@ from .data_utils import SLOTS from .tool_utils import get_tool_utils - if TYPE_CHECKING: from .tool_utils import FunctionCall @@ -129,6 +128,51 @@ def apply(self, **kwargs) -> SLOTS: return elements +@dataclass +class MistralFunctionFormatter(Formatter): + @override + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + functions: List[Tuple[str, str]] = [] + try: + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + + except json.JSONDecodeError: + functions = [] + + elements = [] + for name, arguments in functions: + elements.append(f""""{{"name":"{name}","arguments":{arguments}}}""") + elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"] + + return elements + + +@dataclass +class MistralObservationFormatter(Formatter): + def __post_init__(self): + self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots + + @override + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + tool_results: List[Tuple[str, str]] + try: + tool_results = [json.dumps(result) for result in json.loads(content)] + except json.JSONDecodeError: + tool_results = [] + + elements = [] + for content in tool_results: + elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]") + return ["".join(elements)] + + @dataclass class ToolFormatter(Formatter): def __post_init__(self): diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 9f015b3823..954a35f57e 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -22,10 +22,10 @@ preprocess_packed_supervised_dataset, preprocess_supervised_dataset, print_supervised_dataset_example, + print_flatting_supervised_dataset_example, ) from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example - if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin @@ -78,8 +78,10 @@ 
def __init__(self, data, **kwargs): processor=processor, data_args=data_args, ) - - print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + if data_args.packing and data_args.flatting_packing: + print_function = partial(print_flatting_supervised_dataset_example, tokenizer=tokenizer) + else: + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) elif stage == "rm": preprocess_func = partial( preprocess_pairwise_dataset, diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 8e13d100bc..b7297df34f 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -28,6 +28,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: r""" An efficient greedy algorithm with binary search for the knapsack problem. """ + # filter out numbers that are larger than the capacity + numbers = [number for number in numbers if number <= capacity] numbers.sort() # sort numbers in ascending order for binary search knapsacks = [] @@ -43,6 +45,10 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: remaining_capacity -= numbers[index] # update the remaining capacity current_knapsack.append(numbers.pop(index)) # add the number to knapsack + # avoid endless loop + if remaining_capacity == capacity: + break + knapsacks.append(current_knapsack) return knapsacks diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 666256407a..1406f16a15 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
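Note on the `greedy_knapsack` change above: over-long examples are now filtered out up front and the loop bails out when a pass adds nothing, which is the case that could previously spin forever (the #5443 issue named in the commit message). A minimal standalone sketch of the same idea, using a plain linear scan instead of the project's binary search (function name and numbers are illustrative only):

```python
from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    """Simplified re-implementation of the packing strategy, for illustration only."""
    # drop lengths that can never fit; these are what used to stall the loop
    lengths = sorted(length for length in lengths if length <= capacity)
    knapsacks: List[List[int]] = []
    while lengths:
        remaining = capacity
        current: List[int] = []
        # greedily take the largest remaining length that still fits
        for i in range(len(lengths) - 1, -1, -1):
            if lengths[i] <= remaining:
                remaining -= lengths[i]
                current.append(lengths.pop(i))
        if not current:  # nothing fit this pass: stop instead of looping forever
            break
        knapsacks.append(current)
    return knapsacks


if __name__ == "__main__":
    # the over-long example (9000) is filtered out instead of blocking packing
    print(greedy_knapsack_sketch([1200, 800, 3000, 9000, 500], capacity=4096))
    # [[3000, 800], [1200, 500]]
```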
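To make the intent of the new `SFTDataCollatorWithFlattingPacking` (collator.py above) concrete, here is a dependency-free sketch of the flattening step, assuming each packed feature stores nested `input_ids`/`labels` lists as the packed preprocessor below produces. The real collator additionally runs the result through `default_data_collator` to get tensors; everything here is illustrative, not the project code.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label id ignored by the loss, same role as the collator's label_pad_token_id


def flatten_packed_features(features: List[Dict[str, List[List[int]]]]) -> Dict[str, List[int]]:
    """Concatenate every packed sequence into one row; no padding and no attention mask."""
    flat = {"input_ids": [], "labels": [], "position_ids": []}
    for packed_example in features:
        for input_ids, labels in zip(packed_example["input_ids"], packed_example["labels"]):
            flat["input_ids"] += input_ids
            # mask the first label of each sequence so no token is predicted across a boundary
            flat["labels"] += [IGNORE_INDEX] + labels[1:]
            # position ids restart at 0 for every sequence; FlashAttention-2 style kernels
            # can recover the per-sequence boundaries from these resets
            flat["position_ids"] += list(range(len(input_ids)))
    return flat


if __name__ == "__main__":
    batch = [{"input_ids": [[1, 2, 3], [4, 5]], "labels": [[1, 2, 3], [4, 5]]}]
    print(flatten_packed_features(batch))
    # {'input_ids': [1, 2, 3, 4, 5], 'labels': [-100, 2, 3, -100, 5], 'position_ids': [0, 1, 2, 0, 1]}
```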
- +import itertools from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple @@ -19,7 +19,6 @@ from ...extras.logging import get_logger from .processor_utils import greedy_knapsack, infer_seqlen - if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin @@ -27,7 +26,6 @@ from ..mm_plugin import ImageInput, VideoInput from ..template import Template - logger = get_logger(__name__) @@ -48,18 +46,12 @@ def _encode_supervised_example( messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor) input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) - total_length = len(input_ids) + (1 if template.efficient_eos else 0) if mask_history: encoded_pairs = encoded_pairs[::-1] # high priority for last turns for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): - if total_length >= cutoff_len: - break - - source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) - source_ids = source_ids[:source_len] - target_ids = target_ids[:target_len] - total_length += source_len + target_len + source_len = len(source_ids) + target_len = len(target_ids) if train_on_prompt: source_label = source_ids @@ -132,13 +124,16 @@ def preprocess_packed_supervised_dataset( processor: Optional["ProcessorMixin"], data_args: "DataArguments", ) -> Dict[str, List[Any]]: - # TODO: use `position_ids` to achieve packing # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` valid_num = 0 + invalid_num = 0 batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], [] lengths = [] length2indexes = defaultdict(list) + + # reserved for the padding token / flatting_packing don't need + num_reserved = 0 if data_args.flatting_packing else 1 for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) @@ -154,13 +149,13 @@ def preprocess_packed_supervised_dataset( template=template, tokenizer=tokenizer, processor=processor, - cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token + cutoff_len=data_args.cutoff_len - num_reserved, train_on_prompt=data_args.train_on_prompt, mask_history=data_args.mask_history, ) length = len(input_ids) - if length > data_args.cutoff_len: - logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) + if length > data_args.cutoff_len - num_reserved: + invalid_num += 1 else: lengths.append(length) length2indexes[length].append(valid_num) @@ -170,36 +165,52 @@ def preprocess_packed_supervised_dataset( batch_videos.append(examples["_videos"][i] or []) valid_num += 1 + if invalid_num > 0: + logger.warning( + "Dropped lengthy {} example with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved) + ) + model_inputs = defaultdict(list) - knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved) # reserved for the padding token for knapsack in knapsacks: packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_images, packed_videos = [], [] - for i, length in enumerate(knapsack): - index = length2indexes[length].pop() - 
packed_input_ids += batch_input_ids[index] - packed_labels += batch_labels[index] - packed_images += batch_images[index] - packed_videos += batch_videos[index] - if data_args.neat_packing: - packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 - else: - packed_attention_masks += [1] * len(batch_input_ids[index]) - - if len(packed_input_ids) < data_args.cutoff_len: - pad_length = data_args.cutoff_len - len(packed_input_ids) - packed_input_ids += [tokenizer.pad_token_id] * pad_length - packed_labels += [IGNORE_INDEX] * pad_length - if data_args.neat_packing: - packed_attention_masks += [0] * pad_length - else: - packed_attention_masks += [1] * pad_length # more efficient flash_attn - - if len(packed_input_ids) != data_args.cutoff_len: - raise ValueError("The length of packed example should be identical to the cutoff length.") + + if data_args.flatting_packing: + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids.append(batch_input_ids[index]) + packed_labels.append(batch_labels[index]) + packed_images.append(batch_images[index]) + packed_videos.append(batch_videos[index]) + else: + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] + packed_images += batch_images[index] + packed_videos += batch_videos[index] + if data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) + + # flatting_packing don't need attention masks + if len(packed_input_ids) < data_args.cutoff_len: + pad_length = data_args.cutoff_len - len(packed_input_ids) + packed_input_ids += [tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + if data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn + + # flatting packing don't need pad + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") + model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) model_inputs["images"].append(packed_images or None) model_inputs["videos"].append(packed_videos or None) @@ -213,3 +224,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: " print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) + + +def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"]))) + input_ids = list(itertools.chain(*example["input_ids"])) + print("input_ids:\n{}".format(input_ids)) + print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False))) + print("label_ids:\n{}".format(list(itertools.chain(*example["labels"])))) + print("labels:\n{}".format(tokenizer.decode(valid_labels), skip_special_tokens=False)) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 54da4757f7..89d19be01a 100644 --- 
a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -21,9 +21,9 @@ from ..extras.logging import get_logger from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .formatter import MistralFunctionFormatter, MistralObservationFormatter from .mm_plugin import get_mm_plugin - if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -31,7 +31,6 @@ from .formatter import SLOTS, Formatter from .mm_plugin import BasePlugin - logger = get_logger(__name__) @@ -416,7 +415,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ), ) - _register_template( name="aquila", format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), @@ -429,7 +427,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="atom", format_user=StringFormatter( @@ -438,21 +435,18 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), ) - _register_template( name="baichuan", format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), efficient_eos=True, ) - _register_template( name="baichuan2", format_user=StringFormatter(slots=["{{content}}"]), efficient_eos=True, ) - _register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), @@ -460,13 +454,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="bluelm", format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), ) - _register_template( name="breeze", format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), @@ -474,7 +466,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), @@ -483,7 +474,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="chatglm3", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), @@ -499,7 +489,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="chatml", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -510,7 +499,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="chatml_de", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -522,13 +510,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="codegeex2", format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), ) - _register_template( name="codegeex4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), @@ -545,7 +531,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="cohere", format_user=StringFormatter( @@ -560,14 +545,12 @@ def get_template_and_fix_tokenizer(tokenizer: 
"PreTrainedTokenizer", data_args: format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="cpm", format_user=StringFormatter(slots=["<用户>{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="cpm3", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -576,7 +559,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: stop_words=["<|im_end|>"], ) - _register_template( name="dbrx", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -602,7 +584,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="deepseek", format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), @@ -610,7 +591,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), @@ -625,7 +605,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ), ) - _register_template( name="default", format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), @@ -633,13 +612,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_separator=EmptyFormatter(slots=["\n"]), ) - _register_template( name="empty", efficient_eos=True, ) - _register_template( name="falcon", format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), @@ -647,14 +624,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="fewshot", format_separator=EmptyFormatter(slots=["\n\n"]), efficient_eos=True, ) - _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), @@ -666,7 +641,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), @@ -680,7 +654,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="intern", format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), @@ -691,7 +664,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, # internlm tokenizer cannot set eos_token_id ) - _register_template( name="intern2", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -702,14 +674,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id ) - _register_template( name="llama2", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), ) - _register_template( name="llama2_zh", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), @@ -717,7 +687,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: default_system="You are a helpful assistant. 
你是一个乐于助人的助手。", ) - _register_template( name="llama3", format_user=StringFormatter( @@ -742,7 +711,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="llava", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), @@ -753,28 +721,29 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: mm_plugin=get_mm_plugin(name="llava", image_token=""), ) - _register_template( name="mistral", format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_assistant=StringFormatter(slots=[" {{content}}"]), # mistral add space here format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"), + format_observation=MistralObservationFormatter(tool_format="mistral"), + format_tools=ToolFormatter(tool_format="mistral"), + efficient_eos=True, ) - _register_template( name="olmo", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_prefix=EmptyFormatter(slots=[{"eos_token"}]), ) - _register_template( name="openchat", format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="openchat-3.6", format_user=StringFormatter( @@ -790,14 +759,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="orion", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) - _register_template( name="paligemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), @@ -810,7 +777,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: mm_plugin=get_mm_plugin(name="paligemma", image_token=""), ) - _register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), @@ -821,7 +787,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -833,7 +798,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="qwen2_vl", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -846,7 +810,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), ) - _register_template( name="sailor", format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), @@ -860,7 +823,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="solar", format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), @@ -868,7 +830,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: efficient_eos=True, ) - _register_template( name="starchat", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), @@ -878,7 +839,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - 
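For context on the `flatting_packing` branch added to `preprocess_packed_supervised_dataset` above: each packed example keeps its member sequences as nested lists with no padding and no attention mask, whereas the existing neat-packing path concatenates, numbers each segment in `attention_mask`, and pads to `cutoff_len`. A toy comparison (token ids, pad id, cutoff, and helper names are all made up; real labels would also have prompt tokens masked):

```python
from typing import Dict, List

IGNORE_INDEX = -100
PAD_TOKEN_ID = 0   # illustrative pad id
CUTOFF_LEN = 8     # illustrative cutoff length

# two toy sequences that landed in the same knapsack
sequences = [[11, 12, 13], [21, 22]]


def pack_flatting(seqs: List[List[int]]) -> Dict[str, List[List[int]]]:
    """flatting_packing layout: nested lists, no padding; the collator flattens them later."""
    return {"input_ids": [list(s) for s in seqs], "labels": [list(s) for s in seqs]}


def pack_neat(seqs: List[List[int]]) -> Dict[str, List[int]]:
    """neat_packing layout: concatenate, tag each segment in attention_mask, pad to the cutoff."""
    input_ids, labels, attention_mask = [], [], []
    for seg_id, seq in enumerate(seqs, start=1):  # segment ids start from 1
        input_ids += seq
        labels += seq
        attention_mask += [seg_id] * len(seq)
    pad = CUTOFF_LEN - len(input_ids)  # assumes the knapsack fits within CUTOFF_LEN
    return {
        "input_ids": input_ids + [PAD_TOKEN_ID] * pad,
        "labels": labels + [IGNORE_INDEX] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


if __name__ == "__main__":
    print(pack_flatting(sequences))  # {'input_ids': [[11, 12, 13], [21, 22]], 'labels': [[11, 12, 13], [21, 22]]}
    print(pack_neat(sequences))      # attention_mask: [1, 1, 1, 2, 2, 0, 0, 0]
```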
_register_template( name="telechat", format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), @@ -887,7 +847,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="vicuna", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), @@ -897,7 +856,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ), ) - _register_template( name="xuanyuan", format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), @@ -908,13 +866,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ), ) - _register_template( name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), ) - _register_template( name="yayi", format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), @@ -934,7 +890,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: stop_words=["<|End|>"], ) - _register_template( name="yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), @@ -944,7 +899,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="yi_vl", format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), @@ -961,7 +915,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: mm_plugin=get_mm_plugin(name="llava", image_token=""), ) - _register_template( name="yuan", format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), @@ -970,7 +923,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: replace_eos=True, ) - _register_template( name="zephyr", format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]), @@ -978,7 +930,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: default_system="You are Zephyr, a helpful assistant.", ) - _register_template( name="ziya", format_user=StringFormatter(slots=[":{{content}}\n:"]), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 88027ba69b..8ad284bd1b 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -23,7 +23,6 @@ from .data_utils import SLOTS - DEFAULT_TOOL_PROMPT = ( "You have access to the following tools:\n{tool_text}" "Use the following format if using a tool:\n" @@ -34,12 +33,12 @@ "```\n" ) - GLM4_TOOL_PROMPT = ( "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" ) +MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools} [/AVAILABLE_TOOLS]" FunctionCall = namedtuple("FunctionCall", ["name", "arguments"]) @@ -168,8 +167,36 @@ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: return [(tool_name, json.dumps(arguments, ensure_ascii=False))] +class MistralToolUtils(ToolUtils): + @override + @staticmethod + def get_function_slots() -> SLOTS: + return ["[TOOL_RESULTS] ", "{\"content\": {{results}}}", "[/TOOL_RESULTS]"] + + @override + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tools = [{"type": "function", "function": tool} for tool in tools] + return MISTRAL_TOOL_PROMPT.format(tools=tools) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + if "\n" not in content: + return content + + tool_name, tool_input = 
content.split("\n", maxsplit=1) + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] + + TOOLS = { "default": DefaultToolUtils(), + "mistral": MistralToolUtils(), "glm4": GLM4ToolUtils(), } diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 1adcf2d0df..98f2ca5753 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -105,6 +105,10 @@ class DataArguments: default=False, metadata={"help": "Enable sequence packing without cross-attention."}, ) + flatting_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing by flattening packed sequences without padding (requires FlashAttention-2)."} + ) tool_format: Optional[str] = field( + default=None, metadata={"help": "Tool format to use for constructing function calling examples."}, @@ -148,3 +152,6 @@ def split_arg(arg): if self.mask_history and self.train_on_prompt: raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") + + if self.neat_packing and self.flatting_packing: + raise ValueError("`neat_packing` is incompatible with `flatting_packing`.") diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 43a9aef16f..8cc3674896 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -17,21 +17,24 @@ from typing import TYPE_CHECKING, List, Optional -from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer +from ...data import SFTDataCollatorWith4DAttentionMask, SFTDataCollatorWithFlattingPacking, get_dataset, \ + get_template_and_fix_tokenizer from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss +from ...extras.logging import get_logger from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor from .trainer import CustomSeq2SeqTrainer - if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +logger = get_logger(__name__) + def run_sft( model_args: "ModelArguments", @@ -50,15 +53,29 @@ def run_sft( if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - data_collator = SFTDataCollatorWith4DAttentionMask( - template=template, - pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention - label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, - block_diag_attn=model_args.block_diag_attn, - attn_implementation=getattr(model.config, "_attn_implementation", None), - compute_dtype=model_args.compute_dtype, - **tokenizer_module, - ) + if ( + data_args.packing and + data_args.flatting_packing and + (getattr(model.config, "_attn_implementation", None) != "flash_attention_2") + ): + logger.warning("`flatting_packing` is only supported with `flash_attention_2`; 
other attention implementations may be slow or run out of memory.") + + if (data_args.packing and data_args.flatting_packing): + data_collator = SFTDataCollatorWithFlattingPacking( + template=template, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + **tokenizer_module, + ) + else: + data_collator = SFTDataCollatorWith4DAttentionMask( + template=template, + pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, + **tokenizer_module, + ) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
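A self-contained sketch of the wire format the new Mistral formatters and `MistralToolUtils` target: tool specs go inside `[AVAILABLE_TOOLS] ... [/AVAILABLE_TOOLS]`, assistant calls are serialized after `[TOOL_CALLS]` as a JSON list, and tool outputs are wrapped in `[TOOL_RESULTS] {"content": ...}[/TOOL_RESULTS]`. The control tokens are hard-coded strings here and every helper name is made up; the sketch also normalizes everything through `json.dumps`, whereas the patched `tool_formatter` interpolates the Python list directly.

```python
import json
from typing import Any, Dict, List


def render_available_tools(tools: List[Dict[str, Any]]) -> str:
    """System-side prompt: advertise the tools between [AVAILABLE_TOOLS] markers."""
    wrapped = [{"type": "function", "function": tool} for tool in tools]
    return "[AVAILABLE_TOOLS] {} [/AVAILABLE_TOOLS]".format(json.dumps(wrapped, ensure_ascii=False))


def render_tool_calls(calls: List[Dict[str, Any]]) -> str:
    """Assistant-side serialization: a JSON list of {name, arguments} after [TOOL_CALLS]."""
    rendered = [{"name": call["name"], "arguments": call["arguments"]} for call in calls]
    return "[TOOL_CALLS] " + json.dumps(rendered, ensure_ascii=False)


def render_tool_results(results: List[Any]) -> str:
    """Observation-side serialization: each result wrapped in a [TOOL_RESULTS] block."""
    return "".join(
        '[TOOL_RESULTS] {{"content": {}}}[/TOOL_RESULTS]'.format(json.dumps(result, ensure_ascii=False))
        for result in results
    )


if __name__ == "__main__":
    tools = [{
        "name": "get_weather",
        "description": "Look up the current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }]
    print(render_available_tools(tools))
    print(render_tool_calls([{"name": "get_weather", "arguments": {"city": "Paris"}}]))
    print(render_tool_results([{"temperature": 21}]))
```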
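On why the feature is tied to FlashAttention-2 (the warning above): the flattened batch carries no attention mask, so sequence boundaries must be recovered from the points where `position_ids` reset to zero, which is exactly the cumulative-sequence-length input of varlen flash-attention kernels; recent Transformers releases derive this internally on the FlashAttention-2 path. A rough sketch of that derivation (the helper name is made up and is not a project or Transformers API):

```python
import torch


def position_ids_to_cu_seqlens(position_ids: torch.Tensor) -> torch.Tensor:
    """Derive cumulative sequence lengths from flattened position ids.

    Every time position_ids drops back to 0, a new packed sequence starts. With these
    boundaries as cu_seqlens, varlen attention keeps packed sequences from attending
    to each other even though they share a single row.
    """
    position_ids = position_ids.flatten()
    starts = torch.nonzero(position_ids == 0, as_tuple=True)[0]            # sequence start offsets
    total = torch.tensor([position_ids.numel()], device=position_ids.device)
    return torch.cat([starts, total]).to(torch.int32)


if __name__ == "__main__":
    pos = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])  # three packed sequences of lengths 3, 2, 4
    print(position_ids_to_cu_seqlens(pos))           # tensor([0, 3, 5, 9], dtype=torch.int32)
```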