From 9b934bd46cf8e8d17b2ddf5910f848cb33bfaba8 Mon Sep 17 00:00:00 2001 From: "vejvarm@freya" Date: Tue, 28 Nov 2023 13:55:29 +0900 Subject: [PATCH] refactor: count total samples per split and save to stats.json --- lab/count_triples.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lab/count_triples.py b/lab/count_triples.py index b471157..8a4f5c3 100644 --- a/lab/count_triples.py +++ b/lab/count_triples.py @@ -36,11 +36,12 @@ def categorize_by_num_of_triples(input_dir, splits=("test", "train", "val")): def main(args): total, num_triples_counter, convo_len_counter = categorize_by_num_of_triples(args.input) - stats = {**num_triples_counter} + stats = dict(**num_triples_counter) + stats["total_samples"] = {k: sum(c.values()) for k, c in num_triples_counter.items()} total_counter = Counter() [total_counter.update(ctr) for ctr in num_triples_counter.values()] - stats["total"] = total_counter + stats["total_triples"] = total_counter stats["convo_len"] = convo_len_counter print(f"FILES TOTAL: {total}")