Skip to content

Commit

Permalink
Make AlpacaToMessage public. (pytorch#1785)
Browse files Browse the repository at this point in the history
  • Loading branch information
krammnic authored and mori360 committed Oct 14, 2024
1 parent ea7080b commit c8425ad
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 70 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Converts data from common schema and conversation JSON formats into a list of to
ShareGPTToMessages
OpenAIToMessages
ChosenRejectedToMessages
AlpacaToMessages

Collaters
---------
Expand Down
2 changes: 2 additions & 0 deletions torchtune/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
from torchtune.data._instruct_templates import InstructTemplate
from torchtune.data._messages import (
AlpacaToMessages,
ChosenRejectedToMessages,
InputOutputToMessages,
Message,
Expand All @@ -43,6 +44,7 @@
"SummarizeTemplate",
"OpenAIToMessages",
"ShareGPTToMessages",
"AlpacaToMessages",
"truncate",
"Message",
"validate_messages",
Expand Down
67 changes: 67 additions & 0 deletions torchtune/data/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,70 @@ def validate_messages(
f"System message at index {i} in messages, but system messages must come first"
)
last_turn = message.role


class AlpacaToMessages(Transform):
    """
    Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
    (or equivalent fields specified in column_map) columns. User messages are formed from the
    instruction + input columns and assistant messages are formed from the output column. Prompt
    templating is conditional on the presence of the "input" column, and thus is handled directly
    in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
    due to this custom logic.

    Args:
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is True.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input",
            and "output" column names to the actual column names in the dataset. Default is None,
            keeping the default column names.
    """

    def __init__(
        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
    ):
        self.train_on_input = train_on_input
        self.column_map = column_map
        # Two prompt variants: one used when the sample carries a non-empty "input"
        # field, one used when it does not. This is why templating lives here rather
        # than in a PromptTemplate class (see class docstring).
        self.template = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:\n"
            ),
        }

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """Convert a single Alpaca-style sample into a user/assistant message pair.

        Args:
            sample (Mapping[str, Any]): a single example with "instruction", "output",
                and optionally "input" keys (or their column_map equivalents).

        Returns:
            Mapping[str, Any]: dict with a single "messages" key holding the
            user message followed by the assistant message.
        """
        # Resolve dataset-specific column names, falling back to the defaults.
        column_map = self.column_map or {}
        key_input = column_map.get("input", "input")
        key_instruction = column_map.get("instruction", "instruction")
        key_output = column_map.get("output", "output")

        # Use the "input" prompt variant only when the column is present AND
        # non-empty (empty strings fall through to the no-input template).
        if key_input in sample and sample[key_input]:
            prompt = self.template["prompt_input"].format(
                instruction=sample[key_instruction], input=sample[key_input]
            )
        else:
            prompt = self.template["prompt_no_input"].format(
                instruction=sample[key_instruction]
            )

        messages = [
            Message(
                role="user",
                content=prompt,
                # The user turn is masked from the loss unless train_on_input is set.
                masked=not self.train_on_input,
                eot=True,
            ),
            Message(
                role="assistant",
                content=sample[key_output],
                masked=False,
                eot=True,
            ),
        ]
        return {"messages": messages}
73 changes: 3 additions & 70 deletions torchtune/datasets/_alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,80 +6,13 @@

from functools import partial

from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, Optional, Union

from torchtune.data._messages import AlpacaToMessages

from torchtune.data._messages import Message
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class AlpacaToMessages(Transform):
    """
    Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
    (or equivalent fields specified in column_map) columns.

    User messages are built from the instruction + input columns and assistant
    messages from the output column. Because the prompt template depends on
    whether the sample has a non-empty "input" column, the templating is done
    directly in this transform rather than in a dedicated
    :class:`~torchtune.data.PromptTemplate` class.

    Args:
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is True.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input",
            and "output" column names to the actual column names in the dataset. Default is None,
            keeping the default column names.
    """

    def __init__(
        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
    ):
        self.train_on_input = train_on_input
        self.column_map = column_map
        # One template per case: sample with a non-empty "input" vs. without.
        self.template = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:\n"
            ),
        }

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """Transform one Alpaca-style sample into {"messages": [user, assistant]}."""
        remap = self.column_map or {}
        instruction_key = remap.get("instruction", "instruction")
        input_key = remap.get("input", "input")
        output_key = remap.get("output", "output")

        # A missing or empty/falsy "input" column selects the no-input template.
        input_value = sample.get(input_key)
        if input_value:
            prompt = self.template["prompt_input"].format(
                instruction=sample[instruction_key], input=input_value
            )
        else:
            prompt = self.template["prompt_no_input"].format(
                instruction=sample[instruction_key]
            )

        user_message = Message(
            role="user",
            content=prompt,
            masked=not self.train_on_input,
            eot=True,
        )
        assistant_message = Message(
            role="assistant",
            content=sample[output_key],
            masked=False,
            eot=True,
        )
        return {"messages": [user_message, assistant_message]}


def alpaca_dataset(
Expand Down

0 comments on commit c8425ad

Please sign in to comment.