Skip to content

Commit

Permalink
refactor: count total samples per split and save to stats.json
Browse files Browse the repository at this point in the history
  • Loading branch information
vejvarm committed Nov 28, 2023
1 parent 6b14800 commit 9b934bd
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lab/count_triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def categorize_by_num_of_triples(input_dir, splits=("test", "train", "val")):
def main(args):
total, num_triples_counter, convo_len_counter = categorize_by_num_of_triples(args.input)

- stats = {**num_triples_counter}
+ stats = dict(**num_triples_counter)
+ stats["total_samples"] = {k: sum(c.values()) for k, c in num_triples_counter.items()}

total_counter = Counter()
[total_counter.update(ctr) for ctr in num_triples_counter.values()]
- stats["total"] = total_counter
+ stats["total_triples"] = total_counter
stats["convo_len"] = convo_len_counter

print(f"FILES TOTAL: {total}")
Expand Down

0 comments on commit 9b934bd

Please sign in to comment.