Skip to content

Commit

Permalink
fix aquila template, repair sft packing mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Oct 10, 2023
1 parent e1dcb8e commit be420e4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
17 changes: 12 additions & 5 deletions src/llmtuner/dsets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ def preprocess_dataset(
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

if template is not None and template.efficient_eos and data_args.sft_packing:
raise ValueError("Current template is incompatible with packing.")

def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
Expand Down Expand Up @@ -105,9 +102,19 @@ def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_ids + target_ids # TODO: try masking source_ids here
labels += source_mask + target_ids

if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]

total_length = len(input_ids)
block_size = data_args.cutoff_len
Expand Down
8 changes: 6 additions & 2 deletions src/llmtuner/extras/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def get_template_and_fix_tokenizer(


r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
Supports: https://huggingface.co/BAAI/AquilaChat-7B
"""
register_template(
name="aquila",
Expand All @@ -439,7 +439,11 @@ def get_template_and_fix_tokenizer(
),
sep=[
"###"
]
],
stop_words=[
"</s>"
],
efficient_eos=True
)


Expand Down

0 comments on commit be420e4

Please sign in to comment.