Skip to content

Commit

Permalink
refactor: Simplify yearbook data generation (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi authored Jun 21, 2024
1 parent 3dadeac commit 70e2a57
Showing 1 changed file with 31 additions and 40 deletions.
71 changes: 31 additions & 40 deletions benchmark/wildtime_benchmarks/data_generation_yearbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import os
import pickle
from pathlib import Path
from typing import Tuple

import numpy as np
Expand All @@ -22,14 +23,14 @@
logger = setup_logger()


def main() -> None:
    """Entry point: parse CLI arguments and download/convert the Yearbook data.

    Expects the Wild-Time argument parser to provide ``dir`` (target
    directory) and ``dummyyear`` (whether to append a final dummy year).
    """
    parser = setup_argparser_wildtime("Yearbook", all_arg=False)
    args = parser.parse_args()

    logger.info(f"Downloading data to {args.dir}")

    downloader = YearbookDownloader(args.dir)
    downloader.store_data(args.dummyyear)


class YearbookDownloader(Dataset):
Expand All @@ -38,18 +39,18 @@ class YearbookDownloader(Dataset):
drive_id = "1mPpxoX2y2oijOvW1ymiHEYd7oMu2vVRb"
file_name = "yearbook.pkl"

def __init__(self, data_dir: Path):
    """Download the yearbook pickle (if missing) and load it into memory.

    Args:
        data_dir: Directory into which ``yearbook.pkl`` is (or was) downloaded.
    """
    super().__init__()
    download_if_not_exists(
        drive_id=self.drive_id,
        destination_dir=data_dir,
        destination_file_name=self.file_name,
    )
    # Use a context manager so the handle is closed deterministically; the
    # original `pickle.load(open(...))` left the file open until GC.
    # NOTE(review): unpickling downloaded data can execute arbitrary code if
    # the file is tampered with — acceptable only because drive_id is trusted.
    with open(data_dir / self.file_name, "rb") as pkl_file:
        self._dataset = pickle.load(pkl_file)
    self.data_dir = data_dir

def _get_year_data(self, year: int, create_test_data: bool) -> tuple[dict[str, list[tuple]], dict[str, int]]:
def _get_year_data(self, year: int) -> tuple[dict[str, list[tuple]], dict[str, int]]:
def get_split_by_id(split: int) -> list[Tuple]:
images = torch.FloatTensor(
np.array(
Expand All @@ -64,61 +65,51 @@ def get_split_by_id(split: int) -> list[Tuple]:
labels = torch.LongTensor(self._dataset[year][split]["labels"])
return [(images[i], labels[i]) for i in range(len(images))]

if not create_test_data:
train_size = len(get_split_by_id(0))
ds = {"train": get_split_by_id(0)}
stats = { "train": train_size }
else:
train_size = len(get_split_by_id(0))
test_size = len(get_split_by_id(1))
ds = {"train": get_split_by_id(0), "test": get_split_by_id(1)}
stats = {"train": train_size, "test": test_size}
train_size = len(get_split_by_id(0))
test_size = len(get_split_by_id(1))
all_size = len(get_split_by_id(2))
ds = {"train": get_split_by_id(0), "test": get_split_by_id(1), "all": get_split_by_id(2)}
stats = {"train": train_size, "test": test_size, "all": all_size}
return ds, stats

def __len__(self) -> int:
    """Return the number of label entries held by the loaded dataset."""
    labels = self._dataset["labels"]
    return len(labels)

def store_data(self, add_final_dummy_year: bool) -> None:
    """Write one binary file per (year, split) plus per-year statistics.

    Creates ``train``/``test``/``all`` subdirectories under ``data_dir``,
    stores each year's samples of every split as ``<year>.bin`` with a fake
    timestamp, and dumps per-year split sizes to ``overall_stats.json``.

    Args:
        add_final_dummy_year: If True, append one extra year after the last
            real one containing a single sample copied from the final year's
            train split.
    """
    import json  # stdlib; hoisted out of the `with` block it used to sit in

    # Create split directories; parents=True also creates a missing data_dir,
    # replacing the old os.path.exists / os.mkdir pre-check and keeping the
    # method consistent with data_dir being a pathlib.Path.
    split_dirs = {name: self.data_dir / name for name in ["train", "test", "all"]}
    for split_dir in split_dirs.values():
        split_dir.mkdir(parents=True, exist_ok=True)

    overall_stats = {}
    for year in self.time_steps:
        ds, stats = self._get_year_data(year)
        overall_stats[year] = stats
        for split, split_dir in split_dirs.items():
            self.create_binary_file(
                ds[split], split_dir / f"{year}.bin", create_fake_timestamp(year, base_year=1930)
            )

    with open(self.data_dir / "overall_stats.json", "w") as f:
        json.dump(overall_stats, f, indent=4)

    if add_final_dummy_year:
        # NOTE(review): relies on `year`/`ds` leaking from the loop above,
        # i.e. assumes self.time_steps is non-empty — confirm upstream.
        dummy_year = year + 1
        dummy_data = [ds["train"][0]]  # one sample from the previous year
        for split_dir in split_dirs.values():
            self.create_binary_file(
                dummy_data, split_dir / f"{dummy_year}.bin", create_fake_timestamp(dummy_year, base_year=1930)
            )

@staticmethod
def create_binary_file(data, output_file_name: str, timestamp: int) -> None:
def create_binary_file(data: list[tuple], output_file_name: Path, timestamp: int) -> None:
with open(output_file_name, "wb") as f:
for tensor1, tensor2 in data:
features_bytes = tensor1.numpy().tobytes()
Expand Down

0 comments on commit 70e2a57

Please sign in to comment.