fix reserved label len
hiyouga committed Feb 4, 2024
1 parent 19d33ed · commit db0ab4d
Showing 3 changed files with 33 additions and 13 deletions.
28 changes: 24 additions & 4 deletions src/llmtuner/data/preprocess.py
@@ -55,7 +55,12 @@ def preprocess_supervised_dataset(
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
             template.encode_multiturn(
-                tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
             )
         ):
             if data_args.train_on_prompt:
@@ -143,7 +148,12 @@ def preprocess_unsupervised_dataset(
             messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]

         input_ids, labels = template.encode_oneturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
@@ -172,10 +182,20 @@ def preprocess_pairwise_dataset(
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]

         prompt_ids, chosen_ids = template.encode_oneturn(
-            tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            chosen_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         _, rejected_ids = template.encode_oneturn(
-            tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            rejected_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
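
For context on why the extra argument matters: when truncation is driven by cutoff_len alone, a long prompt can consume the whole budget and leave no tokens for the label the loss is computed on. A minimal, self-contained sketch of that failure mode (illustrative names only, not code from this repository):

from typing import List, Tuple


def naive_truncate(prompt_ids: List[int], label_ids: List[int], cutoff_len: int) -> Tuple[List[int], List[int]]:
    # Truncate the pair to cutoff_len, prompt first: the label only gets whatever budget is left.
    prompt_ids = prompt_ids[:cutoff_len]
    label_ids = label_ids[: max(cutoff_len - len(prompt_ids), 0)]
    return prompt_ids, label_ids


prompt, label = list(range(2000)), list(range(8))
p, t = naive_truncate(prompt, label, cutoff_len=1024)
print(len(p), len(t))  # 1024 0 -> every label token is dropped, so this example contributes nothing to the loss

Forwarding data_args.reserved_label_len into the encoders lets the template guarantee a minimum slice of the budget for the response before the prompt is cut.
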
12 changes: 6 additions & 6 deletions src/llmtuner/data/template.py
@@ -37,7 +37,7 @@ def encode_oneturn(
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ def encode_multiturn(
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -144,10 +144,10 @@ def _make_pairs(
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][:max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
-            total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
-            encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
+            source_ids = encoded_messages[i][:max_source_len]
+            target_ids = encoded_messages[i + 1][:max_target_len]
+            total_length += len(source_ids) + len(target_ids)
+            encoded_pairs.append((source_ids, target_ids))

         return encoded_pairs
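
The actual split between source and target is delegated to the infer_max_len helper seen in the context lines above, whose implementation is not part of this diff. As a rough sketch of what such a helper can look like (an assumption, not the repository's code), the target receives a share of the remaining budget proportional to its length but never less than reserved_label_len, and the source takes the rest:

from typing import Tuple


def split_budget(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    # Hypothetical split: proportional share for the target, floored at reserved_label_len.
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - max_target_len
    return max_source_len, max_target_len


print(split_budget(source_len=2000, target_len=1, max_len=1024, reserved_label_len=1))   # (1023, 1)
print(split_budget(source_len=2000, target_len=8, max_len=1024, reserved_label_len=16))  # (1008, 16)

Note that the preprocessing functions above now pass data_args.reserved_label_len explicitly, so lowering the signature default from 16 to 1 only affects callers that rely on the default.
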
6 changes: 3 additions & 3 deletions src/llmtuner/hparams/data_args.py
@@ -21,10 +21,10 @@ class DataArguments:
         default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
     )
     cutoff_len: Optional[int] = field(
-        default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
+        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
     )
     reserved_label_len: Optional[int] = field(
-        default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
+        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
     )
     train_on_prompt: Optional[bool] = field(
         default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -57,7 +57,7 @@ class DataArguments:
     ignore_pad_token_for_loss: Optional[bool] = field(
         default=True,
         metadata={
-            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
         },
     )
     val_size: Optional[float] = field(
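
Assuming the hyperparameter dataclasses are parsed with transformers.HfArgumentParser, as is the usual Hugging Face convention, each field above surfaces as a command-line flag with the default shown. A minimal standalone sketch using a trimmed stand-in for DataArguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TinyDataArguments:
    # Trimmed stand-in for DataArguments, for illustration only.
    cutoff_len: Optional[int] = field(
        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
    )
    reserved_label_len: Optional[int] = field(
        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
    )


(args,) = HfArgumentParser(TinyDataArguments).parse_args_into_dataclasses(
    args=["--cutoff_len", "2048", "--reserved_label_len", "16"]
)
print(args.cutoff_len, args.reserved_label_len)  # 2048 16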
