Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

[tasks/huggingface] Way to add to message but not "text" #4516

Merged
merged 1 commit into from
Apr 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions parlai/tasks/huggingface/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
import datasets


class SplitsMappingDict(TypedDict):
train: str
valid: str
test: str


class AbstractHuggingFaceTeacher(DialogTeacher):
"""
Abstract parent class for HuggingFace teachers. Extend this class and specify the
Expand All @@ -29,6 +23,7 @@ class AbstractHuggingFaceTeacher(DialogTeacher):
hf_path = path parameter passed into hugging face load_dataset function
hf_name = name parameter passed into hugging face load_dataset function
hf_text_fields = list of names of the data fields from the dataset to be included in the text/query
hf_message_fields = [optional] list of names of the data fields from the dataset to be included in the message object but *not* text
hf_label_field = name of the data field from the hf dataset that specifies the label of the episode
hf_splits_mapping = dictionary mapping with the keys 'train', 'valid', and 'test', that map to the
names of the splits of the hf dataset.
Expand All @@ -52,26 +47,6 @@ def _path(self, opt):
)
return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold)

@property
def hf_path(self) -> str:
raise NotImplementedError

@property
def hf_name(self) -> Optional[str]:
return None

@property
def hf_text_fields(self) -> List[str]:
raise NotImplementedError

@property
def hf_label_field(self) -> str:
raise NotImplementedError

@property
def hf_splits_mapping(self) -> SplitsMappingDict:
raise NotImplementedError

def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]:
"""
return the constructed text query and dict mapping text field names to values.
Expand All @@ -83,7 +58,14 @@ def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]:
if text_part is None:
raise KeyError(f'Feature "{col}" not found in data.')
text_dict[col] = text_part
return '\n'.join(text_dict.values()), text_dict
query = '\n'.join(text_dict.values())
if hasattr(self, "hf_message_fields"):
for col in self.hf_message_fields:
text_part = row.get(col)
if text_part is None:
raise KeyError(f'Feature "{col}" not found in data.')
text_dict[col] = text_part
return query, text_dict

def _get_label_value(self, row):
"""
Expand Down