Skip to content

Commit

Permalink
__getitem__ of FGNet correction and test improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
franberchez committed Jul 23, 2024
1 parent e808664 commit d8fe4c6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 53 deletions.
2 changes: 2 additions & 0 deletions dlordinal/datasets/fgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
image = self.transform(image)

target = int(self.data.iloc[index]["category"])
if self.target_transform:
target = self.target_transform(target)

return image, target

Expand Down
130 changes: 77 additions & 53 deletions dlordinal/datasets/tests/test_fgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,107 +3,131 @@

import numpy as np
import pytest
import torch
from PIL import Image
from torchvision.transforms import ToTensor

from dlordinal.datasets import FGNet

TMP_DIR = "./tmp_test_dir_fgnet"


@pytest.fixture
def fgnet_instance():
def fgnet_train():
root = TMP_DIR
fgnet = FGNet(
root,
download=True,
# transform=Compose([ToTensor()]),
target_transform=np.array,
train=True,
)
return fgnet


def test_download(fgnet_instance):
fgnet_instance.download()
assert fgnet_instance._check_integrity_download()
@pytest.fixture
def fgnet_test():
root = TMP_DIR
fgnet = FGNet(
root,
download=True,
train=False,
)
return fgnet


def test_process(fgnet_instance):
fgnet_instance.process(
fgnet_instance.root / "FGNET/images",
fgnet_instance.root / "FGNET/data_processed",
def test_download(fgnet_train):
fgnet_train.download()
assert fgnet_train._check_integrity_download()


def test_process(fgnet_train):
fgnet_train.process(
fgnet_train.root / "FGNET/images",
fgnet_train.root / "FGNET/data_processed",
)
assert fgnet_instance._check_integrity_process()
assert fgnet_train._check_integrity_process()


def test_split(fgnet_instance):
fgnet_instance.split(
fgnet_instance.root / "FGNET/data_processed/fgnet.csv",
fgnet_instance.root / "FGNET/data_processed/train.csv",
fgnet_instance.root / "FGNET/data_processed/test.csv",
fgnet_instance.root / "FGNET/data_processed",
fgnet_instance.root / "FGNET/train",
fgnet_instance.root / "FGNET/test",
def test_split(fgnet_train):
fgnet_train.split(
fgnet_train.root / "FGNET/data_processed/fgnet.csv",
fgnet_train.root / "FGNET/data_processed/train.csv",
fgnet_train.root / "FGNET/data_processed/test.csv",
fgnet_train.root / "FGNET/data_processed",
fgnet_train.root / "FGNET/train",
fgnet_train.root / "FGNET/test",
)
assert fgnet_instance._check_integrity_split()
assert fgnet_train._check_integrity_split()


def test_find_category(fgnet_instance):
assert fgnet_instance.find_category(1) == 0
assert fgnet_instance.find_category(9) == 1
assert fgnet_instance.find_category(14) == 2
assert fgnet_instance.find_category(21) == 3
assert fgnet_instance.find_category(33) == 4
def test_find_category(fgnet_train):
assert fgnet_train.find_category(1) == 0
assert fgnet_train.find_category(9) == 1
assert fgnet_train.find_category(14) == 2
assert fgnet_train.find_category(21) == 3
assert fgnet_train.find_category(33) == 4


def test_get_age_from_filename(fgnet_instance):
def test_get_age_from_filename(fgnet_train):
filename = "001A12X_X.jpg"
assert fgnet_instance.get_age_from_filename(filename) == 12
assert fgnet_train.get_age_from_filename(filename) == 12


def test_load_data(fgnet_instance):
data = fgnet_instance.load_data(fgnet_instance.root / "FGNET/images")
def test_load_data(fgnet_train):
data = fgnet_train.load_data(fgnet_train.root / "FGNET/images")
assert len(data) > 0


def test_process_images_from_df(fgnet_instance):
data = fgnet_instance.load_data(fgnet_instance.root / "FGNET/images")
processed_images = list(
(fgnet_instance.root / "FGNET/data_processed").rglob("*.JPG")
)
def test_process_images_from_df(fgnet_train):
data = fgnet_train.load_data(fgnet_train.root / "FGNET/images")
processed_images = list((fgnet_train.root / "FGNET/data_processed").rglob("*.JPG"))
assert len(processed_images) == len(data)


def test_split_dataframe(fgnet_instance):
csv_path = fgnet_instance.root / "FGNET/data_processed/fgnet.csv"
train_images_path = fgnet_instance.root / "FGNET/train"
original_images_path = fgnet_instance.root / "FGNET/images"
test_images_path = fgnet_instance.root / "FGNET/test"
train_df, test_df = fgnet_instance.split_dataframe(
def test_split_dataframe(fgnet_train):
csv_path = fgnet_train.root / "FGNET/data_processed/fgnet.csv"
train_images_path = fgnet_train.root / "FGNET/train"
original_images_path = fgnet_train.root / "FGNET/images"
test_images_path = fgnet_train.root / "FGNET/test"
train_df, test_df = fgnet_train.split_dataframe(
csv_path, train_images_path, original_images_path, test_images_path
)
assert len(train_df) > 0
assert len(test_df) > 0


def test_getitem(fgnet_instance):
img, label = fgnet_instance[0]
assert isinstance(img, Image.Image)
assert img.mode == "RGB"
assert img.size == (128, 128)
assert label in range(6)
def test_getitem(fgnet_train, fgnet_test):
for fgnet in [fgnet_train, fgnet_test]:
for i in range(len(fgnet)):
assert isinstance(fgnet[i][0], Image.Image)
assert isinstance(fgnet[i][1], int)
assert fgnet[i][1] == fgnet.targets[i]
assert np.array(fgnet[i][0]).ndim == 3

fgnet.transform = ToTensor()

for i in range(len(fgnet)):
assert isinstance(fgnet[i][0], torch.Tensor)
assert isinstance(fgnet[i][1], int)
assert fgnet[i][1] == fgnet.targets[i]
assert len(fgnet[i][0].shape) == 3

fgnet.target_transform = lambda target: np.array(target)
for i in range(len(fgnet)):
assert isinstance(fgnet[i][0], torch.Tensor)
assert isinstance(fgnet[i][1], np.ndarray)
assert np.array_equal(fgnet[i][1], fgnet.targets[i])


def test_len(fgnet_instance):
assert len(fgnet_instance) > 0
def test_len(fgnet_train):
assert len(fgnet_train) > 0


def test_targets(fgnet_instance):
assert len(fgnet_instance.targets) > 0
def test_targets(fgnet_train):
assert len(fgnet_train.targets) > 0


def test_classes(fgnet_instance):
assert len(fgnet_instance.classes) == 6
def test_classes(fgnet_train):
assert len(fgnet_train.classes) == 6


def test_clean_up():
Expand Down

0 comments on commit d8fe4c6

Please sign in to comment.