Skip to content

Commit

Permalink
Cleanup data pipeline (#103)
Browse files Browse the repository at this point in the history
* Remove DataPipeline

* Fix prediction with image path

* Minor fixes

* Add unittest for contains_any_tensor
  • Loading branch information
zhiqwang authored May 1, 2021
1 parent 0dfc40f commit 13d150e
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 264 deletions.
13 changes: 11 additions & 2 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest
import numpy as np

import torch
from torch import Tensor

from yolort.data import DetectionDataModule
import yolort.data._helper as data_helper
from yolort.data import DetectionDataModule, contains_any_tensor, _helper as data_helper

from typing import Dict


class DataPipelineTester(unittest.TestCase):
def test_contains_any_tensor(self):
dummy_numpy = np.random.randn(3, 6)
self.assertFalse(contains_any_tensor(dummy_numpy))
dummy_tensor = torch.randn(3, 6)
self.assertTrue(contains_any_tensor(dummy_tensor))
dummy_tensors = [torch.randn(3, 6), torch.randn(9, 5)]
self.assertTrue(contains_any_tensor(dummy_tensors))

def test_get_dataset(self):
# Acquire the images and labels from the coco128 dataset
train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train')
Expand Down
3 changes: 1 addition & 2 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

import pytorch_lightning as pl

from yolort.data import COCOEvaluator, DetectionDataModule
import yolort.data._helper as data_helper
from yolort.data import COCOEvaluator, DetectionDataModule, _helper as data_helper

from yolort.models import yolov5s
from yolort.models.yolo import yolov5_darknet_pan_s_r31
Expand Down
4 changes: 2 additions & 2 deletions yolort/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .coco_eval import COCOEvaluator
from .data_pipeline import DataPipeline
from .data_module import DetectionDataModule, VOCDetectionDataModule, COCODetectionDataModule
from .coco_eval import COCOEvaluator
from ._helper import contains_any_tensor
22 changes: 19 additions & 3 deletions yolort/data/_helper.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import random
from pathlib import Path, PosixPath
from zipfile import ZipFile

import torch
from torchvision import ops
from torch import Tensor

from typing import Type, Any

from .coco import COCODetection
from .transforms import collate_fn, default_train_transforms, default_val_transforms

import logging
logger = logging.getLogger(__name__)


def get_coco_api_from_dataset(dataset):
for _ in range(10):
Expand All @@ -24,6 +25,19 @@ def get_coco_api_from_dataset(dataset):
raise NotImplementedError("Currently only supports COCO datasets")


def contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
"""
Determine whether or not a list contains any Type
"""
if isinstance(value, dtype):
return True
if isinstance(value, (list, tuple)):
return any(contains_any_tensor(v, dtype=dtype) for v in value)
elif isinstance(value, dict):
return any(contains_any_tensor(v, dtype=dtype) for v in value.values())
return False


def prepare_coco128(
data_path: PosixPath,
dirname: str = 'coco128',
Expand All @@ -35,6 +49,8 @@ def prepare_coco128(
data_path (PosixPath): root path of coco128 dataset.
dirname (str): the directory name of coco128 dataset. Default: 'coco128'.
"""
logger = logging.getLogger(__name__)

if not data_path.is_dir():
logger.info(f'Create a new directory: {data_path}')
data_path.mkdir(parents=True, exist_ok=True)
Expand Down
68 changes: 26 additions & 42 deletions yolort/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@

from pytorch_lightning import LightningDataModule

from typing import Callable, List, Any, Optional

from .transforms import collate_fn, default_train_transforms, default_val_transforms
from .voc import VOCDetection
from .coco import COCODetection
from .data_pipeline import DataPipeline
from .detection_pipeline import ObjectDetectionDataPipeline

from typing import Callable, List, Any, Optional


class DetectionDataModule(LightningDataModule):
Expand Down Expand Up @@ -79,19 +77,33 @@ def val_dataloader(self, batch_size: int = 16) -> None:

return loader

@property
def data_pipeline(self) -> DataPipeline:
if self._data_pipeline is None:
self._data_pipeline = self.default_pipeline()
return self._data_pipeline

@data_pipeline.setter
def data_pipeline(self, data_pipeline) -> None:
self._data_pipeline = data_pipeline
class COCODetectionDataModule(DetectionDataModule):
def __init__(
self,
data_path: str,
year: str = "2017",
train_transform: Optional[Callable] = default_train_transforms,
val_transform: Optional[Callable] = default_val_transforms,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
train_dataset = self.build_datasets(
data_path, image_set='train', year=year, transforms=train_transform)
val_dataset = self.build_datasets(
data_path, image_set='val', year=year, transforms=val_transform)

super().__init__(train_dataset=train_dataset, val_dataset=val_dataset,
batch_size=batch_size, num_workers=num_workers, *args, **kwargs)

self.num_classes = 80

@staticmethod
def default_pipeline() -> DataPipeline:
return ObjectDetectionDataPipeline()
def build_datasets(data_path, image_set, year, transforms):
ann_file = Path(data_path) / 'annotations' / f"instances_{image_set}{year}.json"
return COCODetection(data_path, ann_file, transforms())


class VOCDetectionDataModule(DetectionDataModule):
Expand Down Expand Up @@ -134,31 +146,3 @@ def build_datasets(data_path, image_set, years, transforms):
return datasets[0], num_classes
else:
return torch.utils.data.ConcatDataset(datasets), num_classes


class COCODetectionDataModule(DetectionDataModule):
def __init__(
self,
data_path: str,
year: str = "2017",
train_transform: Optional[Callable] = default_train_transforms,
val_transform: Optional[Callable] = default_val_transforms,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
train_dataset = self.build_datasets(
data_path, image_set='train', year=year, transforms=train_transform)
val_dataset = self.build_datasets(
data_path, image_set='val', year=year, transforms=val_transform)

super().__init__(train_dataset=train_dataset, val_dataset=val_dataset,
batch_size=batch_size, num_workers=num_workers, *args, **kwargs)

self.num_classes = 80

@staticmethod
def build_datasets(data_path, image_set, year, transforms):
ann_file = Path(data_path).joinpath('annotations').joinpath(f"instances_{image_set}{year}.json")
return COCODetection(data_path, ann_file, transforms())
92 changes: 0 additions & 92 deletions yolort/data/data_pipeline.py

This file was deleted.

69 changes: 0 additions & 69 deletions yolort/data/detection_pipeline.py

This file was deleted.

Loading

0 comments on commit 13d150e

Please sign in to comment.