Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[minor] Add DatatypeHelper func to split data by fold #3892

Merged
merged 4 commits into from
Aug 5, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions parlai/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,15 @@ def is_streaming(cls, datatype: str) -> bool:
return 'stream' in datatype

@classmethod
def split_domains_by_fold(
def split_data_by_fold(
cls,
fold: str,
domains: List[List],
data: List,
train_frac: float,
valid_frac: float,
test_frac: float,
seed: int = 42,
):
"""
Need to be careful about how we set up random to not leak examples between trains
moyapchen marked this conversation as resolved.
Show resolved Hide resolved
if we're in a scenario where a single dataset has different ways of mixing +
matching subcomponents.
"""
assert train_frac + valid_frac + test_frac == 1
if "train" in fold:
start = 0.0
Expand All @@ -140,9 +135,30 @@ def split_domains_by_fold(
start = train_frac + valid_frac
end = 1.0

random.Random(seed).shuffle(data)
return data[int(start * len(data)) : int(end * len(data))]

@classmethod
def split_domains_by_fold(
moyapchen marked this conversation as resolved.
Show resolved Hide resolved
cls,
fold: str,
domains: List[List],
train_frac: float,
valid_frac: float,
test_frac: float,
seed: int = 42,
):
"""
Need to be careful about how we set up random to not leak examples between trains
moyapchen marked this conversation as resolved.
Show resolved Hide resolved
if we're in a scenario where a single dataset has different ways of mixing +
matching subcomponents.
"""
result = []
for domain in domains:
random.Random(seed).shuffle(domain)
result.extend(domain[int(start * len(domain)) : int(end * len(domain))])
result.extend(
cls.split_data_by_fold(
fold, domain, train_frac, valid_frac, test_frac, seed
)
)
random.Random(seed).shuffle(result)
return result