Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

50 train with samples #58

Merged
merged 2 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions alodataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,11 @@ def __init__(
super(BaseDataset, self).__init__(**kwargs)
self.name = name
self.sample = sample
if not self.sample:
self.items = []
self.dataset_dir = self.get_dataset_dir()
self.dataset_dir = self.get_dataset_dir()
if self.sample:
self.items = self.download_sample()
self.dataset_dir = os.path.join(self.vb_folder, "samples")
else:
self.items = []
self.transform_fn = transform_fn
self.ignore_errors = ignore_errors
self.print_errors = print_errors
Expand Down Expand Up @@ -212,6 +211,9 @@ def get_dataset_dir(self) -> str:
"""Look for dataset_dir based on the given name. To work properly a alodataset_config.json
file must be save into /home/USER/.aloception/alodataset_config.json
"""
if self.sample:
return os.path.join(self.vb_folder, "samples")

streaming_dt_config = os.path.join(self.vb_folder, "alodataset_config.json")
if not os.path.exists(streaming_dt_config):
self.set_dataset_dir(None)
Expand Down Expand Up @@ -247,19 +249,19 @@ def set_dataset_dir(self, dataset_dir: str):

if dataset_dir is None:
dataset_dir = _user_prompt(
f"{self.name} does not exist in config file."
f"{self.name} does not exist in config file. "
+ "Do you want to download and use a sample?: (Y)es or (N)o: "
)
if dataset_dir.lower() in ["y", "yes"]:
if dataset_dir.lower() in ["y", "yes"]: # Download sample and change root directory
self.sample = True
return
return os.path.join(self.vb_folder, "samples")
dataset_dir = _user_prompt(f"Please write a new root directory for {self.name} dataset: ")
dataset_dir = os.path.expanduser(dataset_dir)

# Save the config
if not os.path.exists(dataset_dir):
dataset_dir = _user_prompt(
f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}."
f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}. "
+ "Please write a new directory:"
)
dataset_dir = os.path.expanduser(dataset_dir)
Expand Down
5 changes: 3 additions & 2 deletions alodataset/coco_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
if "sample" not in kwargs:
kwargs["sample"] = False

if not kwargs["sample"]:
self.sample = kwargs["sample"]
if not self.sample:
assert img_folder is not None, "When sample = False, img_folder must be given."
assert ann_file is not None, "When sample = False, ann_file must be given."

Expand All @@ -91,8 +92,8 @@ def __init__(
img_folder = os.path.join(dataset_dir, img_folder)
ann_file = os.path.join(dataset_dir, ann_file)
stuff_ann_file = None if stuff_ann_file is None else os.path.join(dataset_dir, stuff_ann_file)
kwargs["sample"] = self.sample

self.sample = kwargs["sample"]
super(CocoDetectionDataset, self).__init__(name=name, root=img_folder, annFile=ann_file, **kwargs)
if self.sample:
return
Expand Down
3 changes: 0 additions & 3 deletions alonet/deformable_detr/train_on_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ def get_arg_parser():
parser = ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser) # Common alonet parser
parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitDeformableDetr.add_argparse_args(parser) # LitDeformableDetr training parser
# parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser
return parser
Expand Down
34 changes: 19 additions & 15 deletions alonet/detr/coco_data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(
train_ann: str = "annotations/instances_train2017.json",
val_folder: str = "val2017",
val_ann: str = "annotations/instances_val2017.json",
sample: bool = False,
**kwargs
):
"""LightningDataModule to use coco dataset in Detr models
Expand Down Expand Up @@ -58,6 +57,7 @@ def __init__(
Arguments entered by the user (kwargs) will replace those stored in args attribute
"""
# Update class attributes with args and kwargs inputs

super().__init__()
alonet.common.pl_helpers.params_update(self, args, kwargs)

Expand All @@ -71,17 +71,16 @@ def __init__(
# Split=Split.TRAIN if not self.train_on_val else Split.VAL,
classes=classes,
name=name,
sample=sample,
)
self.val_loader_kwargs = dict(
img_folder=val_folder,
ann_file=val_ann,
# split=Split.VAL,
classes=classes,
name=name,
sample=sample,
)
self.args = args
self.val_check() # Check val loader and set some previous parameters

@staticmethod
def add_argparse_args(parent_parser):
Expand All @@ -104,18 +103,11 @@ def add_argparse_args(parent_parser):
nargs="+",
help="If no augmentation (--no_augmentation) is used, --size can be used to resize all the frame.",
)
# parser.add_argument("--classes", type=str, default=None, nargs="+", help="List to classes to be filtered in dataset. (%(default)s by default)")
parser.add_argument(
"--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)"
)
return parent_parser

@property
def CATEGORIES(self):
if not hasattr(self, "coco_train"):
self.setup()
if not hasattr(self, "coco_train"):
return None
else:
return self.coco_train.CATEGORIES if hasattr(self.coco_train, "CATEGORIES") else None

def train_transform(self, frame, same_on_sequence: bool = True, same_on_frames: bool = False):
if self.no_augmentation:
if self.size[0] is not None and self.size[1] is not None:
Expand Down Expand Up @@ -158,15 +150,27 @@ def val_transform(

return frame.norm_resnet()

def val_check(self):
# Instance a default loader to set attributes
self.coco_val = alodataset.CocoDetectionDataset(
transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs,
)
self.sample = self.coco_val.sample or self.sample # Update sample if user prompt is given
self.CATEGORIES = self.coco_val.CATEGORIES if hasattr(self.coco_val, "CATEGORIES") else None

def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
# Setup train/val loaders
self.coco_train = alodataset.CocoDetectionDataset(
transform_fn=self.train_transform, **self.train_loader_kwargs
transform_fn=self.train_transform, sample=self.sample, **self.train_loader_kwargs
)
self.coco_val = alodataset.CocoDetectionDataset(
transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs
)
self.coco_val = alodataset.CocoDetectionDataset(transform_fn=self.val_transform, **self.val_loader_kwargs)

def train_dataloader(self):
"""Train dataloader"""
# Init training loader
if not hasattr(self, "coco_train"):
self.setup()
return self.coco_train.train_loader(batch_size=self.batch_size, num_workers=self.num_workers)
Expand Down
5 changes: 1 addition & 4 deletions alonet/detr/train_on_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ def get_arg_parser():
parser = ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser) # Common alonet parser
parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitDetr.add_argparse_args(parser) # LitDetr training parser
# parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser
return parser
Expand All @@ -24,7 +21,7 @@ def main():

# Init the Detr model with the dataset
detr = LitDetr(args)
coco_loader = CocoDetection2Detr(args, sample=args.use_sample)
coco_loader = CocoDetection2Detr(args)

detr.run_train(data_loader=coco_loader, args=args, project="detr", expe_name="detr_50")

Expand Down
3 changes: 2 additions & 1 deletion alonet/raft/data_modules/chairs2raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def __init__(self, args):
def train_dataloader(self):
split = Split.VAL if self.train_on_val else Split.TRAIN
dataset = FlyingChairs2Dataset(split=split, transform_fn=self.train_transform, sample=self.sample)
self.sample = self.sample or dataset.sample
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
dataset = FlyingChairs2Dataset(split=Split.VAL, transform_fn=self.val_transform, sample=self.sample)

self.sample = self.sample or dataset.sample
return dataset.train_loader(batch_size=1, num_workers=self.num_workers, sampler=SequentialSampler)


Expand Down
3 changes: 3 additions & 0 deletions alonet/raft/data_modules/data2raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def add_argparse_args(parent_parser):
parser.add_argument("--num_workers", type=int, default=8, help="num_workers to use on the dataset")
parser.add_argument("--limit_val_batches", type=_int_or_float_type, default=100)
parser.add_argument("--sequential_sampler", action="store_true", help="sample data sequentially (no shuffle)")
parser.add_argument(
"--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)"
)
return parent_parser

def train_transform(self, frame):
Expand Down
3 changes: 0 additions & 3 deletions alonet/raft/train_on_chairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ def get_args_parser():
parser = argparse.ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser, add_pl_args=True)
parser = Chairs2RAFT.add_argparse_args(parser)
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitRAFT.add_argparse_args(parser)
return parser

Expand Down
6 changes: 3 additions & 3 deletions unittest/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_argparse_defaults(parser):
detr_args["weights"] = "detr-r50"
detr_args["train_on_val"] = True
detr_args["fast_dev_run"] = True
detr_args["use_sample"] = True
detr_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**detr_args))
Expand All @@ -38,7 +38,7 @@ def test_detr(mock_args):
def_detr_args["model_name"] = "deformable-detr-r50"
def_detr_args["train_on_val"] = True
def_detr_args["fast_dev_run"] = True
def_detr_args["use_sample"] = True
def_detr_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**def_detr_args))
Expand All @@ -54,7 +54,7 @@ def test_deformable_detr(mock_args):
raft_args["weights"] = "raft-things"
raft_args["train_on_val"] = True
raft_args["fast_dev_run"] = True
raft_args["use_sample"] = True
raft_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**raft_args))
Expand Down