Commit 4a3640e

feat: Tweak data sizes for langid datasets
bdewilde committed Mar 24, 2023
1 parent 538211d commit 4a3640e
Showing 1 changed file with 21 additions and 18 deletions.
scripts/prepare_langid_datasets.py (21 additions, 18 deletions)
@@ -50,9 +50,13 @@ def main():
     nlp.tokenizer = tokenizer_func(1000, True)(nlp)
 
     print("converting train records to docs ...")
-    train_docbin = DocBin(docs=(convert_record(nlp, record) for record in train_data))
-    if args.save_dir:
-        train_docbin.to_disk(args.save_dir / "train.spacy")
+    train_docs = (convert_record(nlp, record) for record in train_data)
+    for i, docs_batch in enumerate(itertoolz.partition_all(50_000, train_docs)):
+        train_docbin = DocBin(docs=docs_batch)
+        if args.save_dir:
+            train_dir = args.save_dir / "train"
+            train_dir.mkdir(exist_ok=True)
+            train_docbin.to_disk(train_dir / f"{i}.spacy")
 
     print("saving train labels to disk ...")
     labels = sorted(set(lang for _, lang in train_data))
@@ -61,9 +65,13 @@ def main():
         json.dump(labels, f)
 
     print("converting test records to docs ...")
-    test_docbin = DocBin(docs=(convert_record(nlp, record) for record in test_data))
-    if args.save_dir:
-        test_docbin.to_disk(args.save_dir / "test.spacy")
+    test_docs = (convert_record(nlp, record) for record in test_data)
+    for i, docs_batch in enumerate(itertoolz.partition_all(50_000, test_docs)):
+        test_docbin = DocBin(docs=docs_batch)
+        if args.save_dir:
+            test_dir = args.save_dir / "test"
+            test_dir.mkdir(exist_ok=True)
+            test_docbin.to_disk(test_dir / f"{i}.spacy")
 
 
 def add_and_parse_args() -> argparse.Namespace:
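
The new save logic streams docs through `itertoolz.partition_all` and writes one `DocBin` shard per 50k docs instead of a single file, which keeps memory bounded for large datasets. A minimal standalone sketch of that pattern (toy records and `spacy.blank` stand in for the script's `convert_record` and real data, and `toolz` stands in for whichever `itertoolz` the script imports):

```python
import pathlib

import spacy
from spacy.tokens import DocBin
from toolz import itertoolz  # the script's actual itertoolz import isn't shown in this diff

nlp = spacy.blank("en")
# toy stand-ins for the script's (text, lang) records and convert_record()
records = [("hello world", "en"), ("bonjour le monde", "fr")] * 10
docs = (nlp.make_doc(text) for text, _lang in records)  # lazy generator, never fully in memory

save_dir = pathlib.Path("datasets")
train_dir = save_dir / "train"
train_dir.mkdir(parents=True, exist_ok=True)
for i, docs_batch in enumerate(itertoolz.partition_all(50_000, docs)):
    # each shard holds at most 50k docs; the last one may be smaller
    DocBin(docs=docs_batch).to_disk(train_dir / f"{i}.spacy")
```

Since spaCy's corpus reader accepts a directory of `.spacy` files as well as a single file, the training config should only need its paths pointed at `train/` and `test/` rather than `train.spacy` / `test.spacy`.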
@@ -99,7 +107,7 @@ def add_and_parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--min-obs",
         type=int,
-        default=300,
+        default=500,
         help="minimum number of observations -- (text, lang) pairs -- in a language "
         "for it to be included in the training dataset",
     )
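
The script's `filter_data_by_lang_count` helper isn't part of this diff, but the effect of raising `--min-obs` from 300 to 500 is that any language with fewer than 500 (text, lang) pairs in the aggregated data gets dropped. A plausible, hypothetical re-implementation of that filter:

```python
from collections import Counter
from typing import List, Tuple

def filter_by_lang_count(data: List[Tuple[str, str]], min_obs: int) -> List[Tuple[str, str]]:
    """Keep only (text, lang) pairs whose language has at least min_obs examples."""
    counts = Counter(lang for _, lang in data)
    return [(text, lang) for text, lang in data if counts[lang] >= min_obs]

# example: with the new default, languages with fewer than 500 examples are excluded
# agg_data = filter_by_lang_count(agg_data, min_obs=500)
```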
@@ -158,7 +166,6 @@ def load_and_agg_data(
     wili = textacy.lang_id._datasets.Wili2018Dataset(src_root_dir.joinpath("wili"))
     wili.download(force=force)
     wili_data = wili.load(iso_lang_map, min_len=min_text_len)
-    random.shuffle(udhr_data)
 
     tatoeba = textacy.lang_id._datasets.TatoebaDataset(src_root_dir.joinpath("tatoeba"))
     tatoeba.download(force=force)
@@ -170,26 +177,22 @@ def main():
 
     # aggregate and sample datasets
     agg_data = (
-        udhr_data
-        + wili_data
-        + get_random_sample(tatoeba_data, 200000, stratify=True, random_state=seed)
-        + get_random_sample(ud_data, 200000, stratify=True, random_state=seed)
+        udhr_data  # only has ~12k examples
+        + get_random_sample(wili_data, 100_000, stratify=True, random_state=seed)
+        + get_random_sample(tatoeba_data, 100_000, stratify=True, random_state=seed)
+        + get_random_sample(ud_data, 100_000, stratify=True, random_state=seed)
         # add additional examples for hard-to-distinguish language groups
-        + get_random_sample(dslcc_data, 50000, stratify=True, random_state=seed)
+        + get_random_sample(dslcc_data, 50_000, stratify=True, random_state=seed)
         # add some extra english examples, since there's apparently a fair amount
         # of english sprinkled throughout other languages, causing meh performance
         + get_random_sample(
             [item for item in tatoeba_data if item[1] == "en"],
-            10000,
+            10_000,
             stratify=False,
             random_state=seed,
         )
     )
 
-    # agg_data = get_random_sample(
-    #     tatoeba_data, 1_000_000, stratify=True, random_state=seed
-    # )
-
     agg_data = filter_data_by_lang_count(agg_data, min_obs)
 
     return agg_data
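
`get_random_sample` is also defined outside this diff; the calls above suggest per-language stratified sampling with a fixed seed. One way such a helper could look (a sketch, not the script's actual implementation):

```python
import random
from collections import defaultdict
from typing import List, Tuple

def stratified_sample(
    data: List[Tuple[str, str]], n: int, random_state: int = 42
) -> List[Tuple[str, str]]:
    """Sample roughly n (text, lang) pairs while preserving each language's share."""
    rng = random.Random(random_state)
    by_lang = defaultdict(list)
    for text, lang in data:
        by_lang[lang].append((text, lang))
    frac = min(1.0, n / max(1, len(data)))
    sample: List[Tuple[str, str]] = []
    for items in by_lang.values():
        k = min(len(items), max(1, round(frac * len(items))))
        sample.extend(rng.sample(items, k))
    return sample

# e.g. a ~100k stratified sample of the WiLI pool, as in the aggregation above
# wili_sample = stratified_sample(wili_data, 100_000, random_state=42)
```

Relative to the previous version, the Tatoeba and UD samples drop from 200k to 100k each and WiLI is now sampled rather than used in full, so the aggregate pool shrinks and is more evenly balanced across sources.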
