Skip to content

Commit

Permalink
Add statistics for CGLM dataset (#516)
Browse files Browse the repository at this point in the history
We add some statistic collection for the different CGLM variants and
helpful plotting scripts.

PR 1/n of porting SIGMOD changes to main. This one should probably not
require an in-depth review, mostly stuff I needed. It is not the prettiest
code, but why bother using `defaultdict`s :D

---------

Co-authored-by: Xianzhe Ma <xianzma@gmail.com>
  • Loading branch information
MaxiBoether and XianzheMa authored Jun 26, 2024
1 parent 72ffdca commit 3e86231
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 4 deletions.
2 changes: 2 additions & 0 deletions benchmark/cglm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Note that there are different versions of CGLM you can generate:

The script will tell you at the end how many classes your dataset contains.
Note that the size of the dataset is not consistent with what is reported in the initial paper on CGLM, but since their code on metadata processing is not open source, we cannot investigate the difference here.
The script will also generate some statistics over the generated dataset.
You might find the `analyze_centeredness.py` script helpful for understanding how many classes are prevalent only in a single year (i.e., how strongly they are dominated by their modal year), and `analyze_samples_per_year_and_class.py` can give, for each class, a visual intuition of the sample distribution over time.

## Regenerate the metadata
This paragraph is only relevant if you are interested in how we generated the `cglm_labels_timestamps_clean.csv` file.
Expand Down
50 changes: 50 additions & 0 deletions benchmark/cglm/analyze_centeredness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import json

import matplotlib.pyplot as plt

# Analyze how strongly each class is centered around its modal year.
#
# Reads the statistics JSON produced by data_generation.py and, for every
# class, aligns its per-year sample distribution on the year with the most
# samples (the mode). The aligned relative frequencies are summed over all
# classes and shown as a bar chart over the offset from the modal year.


def mode_centered_histogram(stats, split="train"):
    """Aggregate each class's per-year sample fractions, centered on its modal year.

    Args:
        stats: Parsed statistics dict; ``stats[split]`` must contain a
            ``per_class`` mapping (class -> total samples) and a
            ``per_year_and_class`` mapping (year -> class -> samples).
        split: Which split of the statistics to analyze.

    Returns:
        A tuple ``(offsets, results)``: ``offsets[i]`` is the offset in years
        relative to a class's modal year, and ``results[i]`` is the sum over
        all classes of the fraction of that class's samples found at that
        offset. Classes with zero samples in the split are skipped (the
        original code divided by zero for them).
    """
    years = sorted(stats[split]["per_year_and_class"].keys())
    num_years = len(years)
    center = num_years // 2
    # Offsets run from -center to (num_years - center): num_years + 1 slots.
    # (The original x-axis used range(-num_years // 2, ...), which is off by
    # one for odd num_years because (-n) // 2 == -((n + 1) // 2) in Python.)
    offsets = [i - center for i in range(num_years + 1)]
    results = [0.0] * (num_years + 1)

    for class_name in stats[split]["per_class"]:
        samples_per_year = [
            stats[split]["per_year_and_class"].get(year, {}).get(class_name, 0)
            for year in years
        ]
        total = sum(samples_per_year)
        if total == 0:
            continue  # no samples for this class in this split; nothing to center

        # Index of the modal year; ties resolve to the earliest year, matching
        # the original strict-greater-than scan.
        mode_idx = max(range(num_years), key=samples_per_year.__getitem__)

        for i, offset in enumerate(offsets):
            source_idx = mode_idx + offset
            if 0 <= source_idx < num_years:
                results[i] += samples_per_year[source_idx] / total

    return offsets, results


if __name__ == "__main__":
    # NOTE(review): data_generation.py writes "dataset_stats.json"; confirm the
    # file is renamed to "hierarchy_stats.json" before running this script.
    with open('hierarchy_stats.json', 'r') as f:
        stats = json.load(f)

    split = "train"

    num_classes = len(stats[split]['per_class'])
    print(f"there are {num_classes} classes")

    offsets, results = mode_centered_histogram(stats, split)
    print(results)

    plt.figure(figsize=(10, 5))
    plt.bar(offsets, results, color='cornflowerblue')
    # Labels describe the aggregate view; the original reused the (stale)
    # per-class title and axis labels from a copied script.
    plt.title(f'Aggregated sample fraction by offset from modal year ({split} split)')
    plt.xlabel('Offset from modal year')
    plt.ylabel('Summed fraction of class samples')
    plt.xticks(offsets, rotation=45)
    plt.tight_layout()
    plt.show()
32 changes: 32 additions & 0 deletions benchmark/cglm/analyze_samples_per_year_and_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json

import matplotlib.pyplot as plt

# Plot, for every class, the number of samples per year in the chosen split.
# Reads the statistics JSON produced by data_generation.py.


def class_samples_per_year(stats, split, class_name):
    """Return ``(sorted_years, counts)`` for one class in one split.

    Years in which the class has no samples contribute a count of 0, so the
    two returned lists always have the same length and align index-wise.
    """
    years = sorted(stats[split]['per_year_and_class'].keys())
    counts = [
        stats[split]['per_year_and_class'].get(year, {}).get(class_name, 0)
        for year in years
    ]
    return years, counts


if __name__ == "__main__":
    # NOTE(review): data_generation.py writes "dataset_stats.json"; confirm the
    # file is renamed to "hierarchy_stats.json" before running this script.
    with open('hierarchy_stats.json', 'r') as f:
        stats = json.load(f)

    split = "train"

    # Use the classes actually present in the statistics instead of the
    # previously hard-coded num_classes = 79, which silently went out of sync
    # with the generated dataset.
    all_classes = list(stats[split]['per_class'].keys())
    try:
        all_classes.sort(key=int)  # labels are stringified ints after the JSON round-trip
    except ValueError:
        all_classes.sort()  # fall back to lexicographic order for non-numeric labels

    print(f"there are {len(all_classes)} classes")

    # One figure per class: bar chart of samples over the available years.
    for class_name in all_classes:
        years, samples_per_year = class_samples_per_year(stats, split, class_name)

        plt.figure(figsize=(10, 5))
        plt.bar(years, samples_per_year, color='cornflowerblue')
        plt.title(f'Number of Samples per Year for Class: {class_name}')
        plt.xlabel('Year')
        plt.ylabel('Number of Samples')
        plt.xticks(years, rotation=45)
        plt.tight_layout()
        plt.show()
38 changes: 34 additions & 4 deletions benchmark/cglm/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import logging
import os
import pathlib
Expand Down Expand Up @@ -85,14 +86,13 @@ def main():

label_to_new_label = {old_label: label for label, old_label in enumerate(df[label_column].unique())}
df['label'] = df[label_column].map(label_to_new_label)

df["year"] = df["upload_date"].apply(lambda x: datetime.fromtimestamp(x).year)
print(f"We got {df.shape[0]} samples with {len(label_to_new_label)} classes for this configuration. Generating subset.")

if args.eval_split:
if args.eval_split == "uniform":
train_df, eval_df = train_test_split(df, test_size=0.1, random_state=42)
elif args.eval_split == "yearly":
df["year"] = df["upload_date"].apply(lambda x: datetime.fromtimestamp(x).year)
train_dfs, eval_dfs = [], []
for _year, group in df.groupby("year"):
train_group, eval_group = train_test_split(group, test_size=0.1, random_state=42)
Expand All @@ -105,7 +105,10 @@ def main():
else:
loop_iterator = [("train", df, args.output / identifier / "train")]

overall_stats = {}
for split, split_df, output_dir in loop_iterator:
split_stats = {"total_samples": 0, "total_classes": len(label_to_new_label), "per_year": {}, "per_class": {}, "per_year_and_class": {}}

output_dir.mkdir(parents=True, exist_ok=True)

logger.info(f"Generating {split} split with {len(split_df)} samples...")
Expand All @@ -115,6 +118,29 @@ def main():
timestamp = row["upload_date"]
file_path = args.sourcedir / f"{id[0]}/{id[1]}/{id[2]}/{id}.jpg"

# Generate stats
year = row["year"]
split_stats["total_samples"] += 1
if year not in split_stats["per_year"]:
split_stats["per_year"][year] = 1
else:
split_stats["per_year"][year] += 1

if label not in split_stats["per_class"]:
split_stats["per_class"][label] = 1
else:
split_stats["per_class"][label] += 1

if year not in split_stats["per_year_and_class"]:
split_stats["per_year_and_class"][year] = {}
split_stats["per_year_and_class"][year][label] = 1
else:
if label not in split_stats["per_year_and_class"][year]:
split_stats["per_year_and_class"][year][label] = 1
else:
split_stats["per_year_and_class"][year][label] += 1

# Create files
if not file_path.exists():
logger.error(f"File {file_path} is supposed to exist, but it does not. Skipping...")
continue
Expand All @@ -132,6 +158,12 @@ def main():
file.write(str(int(label)))
os.utime(output_dir / f"{id}.label", (timestamp, timestamp))

overall_stats[split] = split_stats

with open(args.output / identifier / "dataset_stats.json", "w") as f:
json.dump(overall_stats, f, indent=4)


if args.dummy:
dummy_path = args.output / identifier / "train" / "dummy.jpg"
shutil.copy(file_path, dummy_path) # just use the last file_path
Expand All @@ -147,5 +179,3 @@ def main():

if __name__ == "__main__":
main()


0 comments on commit 3e86231

Please sign in to comment.