Skip to content

Commit

Permalink
1) Working on OI dataset preparation scripts, 2) Simple refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Nov 15, 2018
1 parent ae71514 commit 6808991
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 40 deletions.
20 changes: 0 additions & 20 deletions datasets/oi4bb/prepare.py

This file was deleted.

109 changes: 109 additions & 0 deletions datasets/prep_oi4bb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import argparse
import os
import zipfile
import logging
import pandas as pd

from common.logger_utils import initialize_logging


def parse_args():
    """
    Parse the command-line options of the dataset preparation script.

    Returns
    -------
    argparse.Namespace
        Parsed options: `data_dir`, `save_dir`, `remove_archives`, `logging_file_name`, `log_packages` and
        `log_pip_packages`.
    """
    # Options are registered in the order they are listed, so the `--help` output keeps the same layout.
    option_table = (
        ('--data-dir', dict(
            type=str,
            default='../imgclsmob_data/oi4bb',
            help='working data directory with source files.')),
        ('--save-dir', dict(
            type=str,
            default='../imgclsmob_data/oi4bb',
            help='directory of destination dataset and log-file')),
        ('--remove-archives', dict(
            action='store_true',
            help='remove archives.')),
        ('--logging-file-name', dict(
            type=str,
            default='prepare.log',
            help='filename of log')),
        ('--log-packages', dict(
            type=str,
            default='pandas',
            help='list of python packages for logging')),
        ('--log-pip-packages', dict(
            type=str,
            default='',
            help='list of pip packages for logging')),
    )

    cli = argparse.ArgumentParser(
        description='Prepare dataset for image classification from Open Images V4 Bounding Boxes',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def extract_val(src_dir_path,
                dst_dir_path,
                remove_src,
                val_archive_file_name="validation.zip"):
    """
    Unpack the validation archive into the destination directory.

    Parameters
    ----------
    src_dir_path : str
        Existing directory that contains the archive.
    dst_dir_path : str
        Existing directory that receives the extracted content.
    remove_src : bool
        Whether to delete the archive once it has been extracted.
    val_archive_file_name : str, default 'validation.zip'
        File name of the archive inside `src_dir_path`.
    """
    # Both directories must already exist; the source one is checked first.
    for required_dir in (src_dir_path, dst_dir_path):
        assert os.path.exists(required_dir)

    archive_path = os.path.join(src_dir_path, val_archive_file_name)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(dst_dir_path)

    # The archive is deleted only after it has been closed.
    if not remove_src:
        return
    os.remove(archive_path)


def create_val_cls_list(src_dir_path,
                        dst_dir_path,
                        val_annotation_file_name="validation-annotations-bbox.csv",
                        val_cls_list_file_name="validation-cls.csv"):
    """
    Create a per-image class list from the bounding-box annotations of the validation subset.

    Each image is assigned the label of its largest bounding box, where the size of a box is
    `(XMax - XMin) * (YMax - YMin)`. The result is a CSV file with the columns `ImageID` and `LabelName`,
    one row per image.

    Parameters
    ----------
    src_dir_path : str
        Existing directory that contains the annotation file.
    dst_dir_path : str
        Existing directory in which the class list file is written.
    val_annotation_file_name : str, default 'validation-annotations-bbox.csv'
        File name of the annotation CSV inside `src_dir_path`. It must provide the columns `ImageID`,
        `LabelName`, `XMin`, `XMax`, `YMin` and `YMax`.
    val_cls_list_file_name : str, default 'validation-cls.csv'
        File name of the class list CSV created inside `dst_dir_path`.
    """
    assert (os.path.exists(src_dir_path))
    assert (os.path.exists(dst_dir_path))

    val_annotation_file_path = os.path.join(src_dir_path, val_annotation_file_name)
    # The class list is an output of the script, so it belongs to the destination directory. It used to be
    # written into the source directory, which left the `dst_dir_path` argument unused.
    val_cls_list_file_path = os.path.join(dst_dir_path, val_cls_list_file_name)

    df = pd.read_csv(val_annotation_file_path)
    df2 = df.assign(Square=(df.XMax - df.XMin) * (df.YMax - df.YMin))
    df2 = df2[["ImageID", "LabelName", "Square"]]
    # `idxmax` yields, per image, the row label of the largest box; `loc` then picks exactly those rows.
    df2 = df2.loc[df2.groupby(["ImageID"])["Square"].idxmax()]
    df2 = df2[["ImageID", "LabelName"]]
    df2.to_csv(val_cls_list_file_path, index=False)


def main():
    """
    Entry point of the script: extract the validation archive and create the validation class list.

    Logging is initialized first. If the source data directory does not exist, an error is logged and nothing
    else is done. The destination directory is created when it is missing.
    """
    args = parse_args()

    # Only the call itself matters here; both returned values are unused.
    _, _unused_log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    source_dir, target_dir = args.data_dir, args.save_dir
    if not os.path.exists(source_dir):
        logging.error('Source directory does not exist.')
        return
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # Both preparation steps work on the same pair of directories.
    dir_paths = dict(src_dir_path=source_dir, dst_dir_path=target_dir)
    extract_val(remove_src=args.remove_archives, **dir_paths)
    create_val_cls_list(**dir_paths)


if __name__ == '__main__':
    main()
5 changes: 3 additions & 2 deletions eval_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a model for image classification (Chainer)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Evaluate a model for image classification (Chainer)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions eval_gl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a model for image classification (Gluon)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Evaluate a model for image classification (Gluon)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions eval_ke.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a model for image classification (Keras)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Evaluate a model for image classification (Keras)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--rec-train',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions eval_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a model for image classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Evaluate a model for image classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions eval_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a model for image classification (TensorFlow/TensorPack)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Evaluate a model for image classification (TensorFlow/TensorPack)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions train_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Train a model for image classification (Chainer)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Train a model for image classification (Chainer)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions train_gl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Train a model for image classification (Gluon)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Train a model for image classification (Gluon)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions train_ke.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Train a model for image classification (Keras)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Train a model for image classification (Keras)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--rec-train',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions train_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Train a model for image classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Train a model for image classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down
5 changes: 3 additions & 2 deletions train_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Train a model for image classification (TensorFlow/TensorPack)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
description='Train a model for image classification (TensorFlow/TensorPack)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--data-dir',
type=str,
Expand Down

0 comments on commit 6808991

Please sign in to comment.