Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ torch>=2.0.0
torchvision==0.15.2
torchtyping
tqdm
accelerate<1.10.0
transformers==4.40.1
pytest
git+https://github.com/openai/CLIP.git
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"torchvision==0.15.2",
"torchtyping",
"tqdm",
"accelerate<1.10.0",
"transformers==4.40.1",
"pytest",
]
Expand Down
31 changes: 30 additions & 1 deletion tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from thingsvision import get_extractor
from thingsvision.utils.data import DataLoader, ImageDataset
from torch.utils.data import Subset
from torch.utils.data import Subset, Dataset

DATA_PATH = "./data"
TEST_PATH = "./test_images"
Expand Down Expand Up @@ -361,6 +361,35 @@ def __len__(self) -> int:
return len(self.values)


class MockImageDataset(Dataset):
    """Mock dataset yielding (image, label) pairs for dataloader tests.

    Each item is a random 3x32x32 float tensor together with a label that
    cycles through five classes (idx % 5).
    """

    def __init__(self, size=10):
        # Number of items reported by __len__.
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Random image paired with a deterministic cyclic label.
        return torch.randn(3, 32, 32), torch.tensor(idx % 5)


class MockImageOnlyDataset(Dataset):
    """Mock dataset yielding bare image tensors (no label tuples).

    Used to exercise code paths that expect a dataloader already in
    image-only format.
    """

    def __init__(self, size=10):
        # Number of items reported by __len__.
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # A single random 3x32x32 tensor — deliberately NOT a (image, label) tuple.
        return torch.randn(3, 32, 32)


def iterate_through_all_model_combinations():
for model_config in MODEL_AND_MODULE_NAMES.values():
model_name = model_config["model_name"]
Expand Down
14 changes: 11 additions & 3 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_4D_features(self):
flatten_acts=False,
)
return features

def get_multi_features(self):
model_name = "vgg16_bn"
extractor, _, batches = helper.create_extractor_and_dataloader(
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_storing_4d(self):
)

self.check_file_exists("features", format, False)

def test_storing_multi(self):
features = self.get_multi_features()
for _, feature in features.items():
Expand All @@ -115,6 +115,14 @@ def test_storing_multi(self):
)
self.check_file_exists(f"features", format, False)

def test_extract_multi(self):
    """All modules extracted together must yield the same number of rows."""
    features = self.get_multi_features()
    # Collapse per-module row counts into a set; one distinct value (or an
    # empty set for no modules) means every module saw the same image count.
    distinct_row_counts = {feature.shape[0] for feature in features.values()}
    self.assertTrue(
        len(distinct_row_counts) <= 1,
        "Not all features have the same number of rows!",
    )

def test_splitting_2d(self):
n_splits = 3
features = self.get_2D_features()
Expand Down Expand Up @@ -154,7 +162,7 @@ def test_splitting_4d(self):
file_format="txt",
n_splits=n_splits,
)

def test_splitting_multi(self):
n_splits = 3
features = self.get_multi_features()
Expand Down
101 changes: 100 additions & 1 deletion tests/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import unittest

import warnings
import numpy as np
import torch
from torch.utils.data import DataLoader

import tests.helper as helper
from thingsvision.core.cka import get_cka
from thingsvision.core.rsa import compute_rdm, correlate_rdms, plot_rdm
from thingsvision.utils.storing import save_features
from thingsvision.core.extraction.torch import ImageOnlyDataloaderModifier


class RSATestCase(unittest.TestCase):
Expand Down Expand Up @@ -114,3 +117,99 @@ def test_filenames(self):
if f.endswith("png"):
img_files.append(os.path.join(root, f))
self.assertEqual(sorted(file_names), sorted(img_files))


class TestImageOnlyDataloaderModifier(unittest.TestCase):
    """Tests for ImageOnlyDataloaderModifier.

    The modifier is a context manager that, for dataloaders yielding
    (image, label) tuples, temporarily swaps in a collate function that
    keeps only the images, and restores the original collate function
    on exit.

    NOTE: all assertions use unittest's self.assert* methods. Bare
    ``assert`` statements are stripped under ``python -O`` and produce
    no diagnostic message on failure; the originals mixed both styles.
    """

    def test_context_manager_with_tuple_format_dataloader(self):
        """
        Test 1: Test the context manager with a dataloader that returns (image, label) tuples.
        Should replace collate function and extract only images.
        """
        dataset = helper.MockImageDataset(size=4)
        dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
        modifier = ImageOnlyDataloaderModifier(dataloader)

        original_collate_fn = dataloader.collate_fn

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            with modifier as modified_dataloader:
                # Entering the context should emit exactly one format warning.
                self.assertEqual(len(w), 1)
                self.assertIn(
                    "The dataloader is not in the correct format", str(w[0].message)
                )

                # The collate function must have been swapped out.
                self.assertNotEqual(
                    modified_dataloader.collate_fn, original_collate_fn
                )
                self.assertEqual(
                    modified_dataloader.collate_fn, modifier.new_collate_fn
                )
                self.assertTrue(modifier.should_replace)

                # Batches now contain only image tensors, labels dropped.
                batch = next(iter(modified_dataloader))
                self.assertIsInstance(batch, torch.Tensor)
                self.assertEqual(batch.shape, (2, 3, 32, 32))

        # On exit the original collate function must be restored and remembered.
        self.assertEqual(dataloader.collate_fn, original_collate_fn)
        self.assertEqual(modifier.original_collate_fn, original_collate_fn)

    def test_context_manager_with_image_only_dataloader(self):
        """
        Test 2: Test the context manager with a dataloader that already returns only images.
        Should NOT replace collate function since format is already correct.
        """
        dataset = helper.MockImageOnlyDataset(size=4)
        dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
        modifier = ImageOnlyDataloaderModifier(dataloader)

        original_collate_fn = dataloader.collate_fn

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            with modifier as modified_dataloader:
                # Correct format: no warning and no collate replacement.
                self.assertEqual(len(w), 0)
                self.assertEqual(
                    modified_dataloader.collate_fn, original_collate_fn
                )
                self.assertIs(modifier.should_replace, False)
                self.assertIsNone(modifier.original_collate_fn)

                batch = next(iter(modified_dataloader))
                self.assertIsInstance(batch, torch.Tensor)
                self.assertEqual(batch.shape, (2, 3, 32, 32))

        self.assertEqual(dataloader.collate_fn, original_collate_fn)

    def test_images_only_collate_function(self):
        """
        Test 3: Test the static _images_only_collate function directly.
        Verify it correctly extracts images from tuples.
        """
        mock_batch = [
            (torch.randn(3, 32, 32), torch.tensor(0)),
            (torch.randn(3, 32, 32), torch.tensor(1)),
            (torch.randn(3, 32, 32), torch.tensor(2)),
        ]

        result = ImageOnlyDataloaderModifier._images_only_collate(mock_batch)

        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(result.shape, (3, 3, 32, 32))

        # Each stacked row must match the corresponding input image exactly.
        for i, (original_image, _) in enumerate(mock_batch):
            torch.testing.assert_close(result[i], original_image)

    def test_check_dataloader_format_method(self):
        """
        Bonus Test: Test the _check_dataloader_format method directly.
        """
        # Tuple-yielding dataloader: format check reports True (needs fixing).
        dataset_tuple = helper.MockImageDataset(size=2)
        dataloader_tuple = DataLoader(dataset_tuple, batch_size=1)
        modifier_tuple = ImageOnlyDataloaderModifier(dataloader_tuple)
        self.assertIs(modifier_tuple._check_dataloader_format(), True)

        # Image-only dataloader: format check reports False (already correct).
        dataset_image = helper.MockImageOnlyDataset(size=2)
        dataloader_image = DataLoader(dataset_image, batch_size=1)
        modifier_image = ImageOnlyDataloaderModifier(dataloader_image)
        self.assertIs(modifier_image._check_dataloader_format(), False)
84 changes: 50 additions & 34 deletions thingsvision/core/extraction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def _module_and_output_check(
output_type in self.get_output_types()
), f"\nData type of output feature matrix must be set to one of the following available data types: {self.get_output_types()}\n"

def _save_features(self, features, features_file, extension):
if extension == "npy":
np.save(features_file, features)
elif extension == "pt":
torch.save(features, features_file)
else:
raise ValueError(f"Invalid extension: {extension}")

def extract_features(
self,
batches: Iterator[Union[TensorType["b", "c", "h", "w"], Array]],
Expand Down Expand Up @@ -280,31 +288,35 @@ def extract_features(
features[module_name].append(modules_features[module_name])

if output_dir and (i % step_size == 0 or i == len(batches)):
curr_output_dir = os.path.join(output_dir, module_name)
if not os.path.exists(curr_output_dir):
print(f"Creating output directory: {curr_output_dir}")
os.makedirs(curr_output_dir)

if self.get_backend() == "pt":
features_subset = torch.cat(features[module_name])
if output_type == "ndarray":
features_subset = self._to_numpy(features_subset)
features_subset_file = os.path.join(
output_dir,
f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy",
)
np.save(features_subset_file, features_subset)
else: # output_type = tensor
features_subset_file = os.path.join(
output_dir,
f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.pt",
)
torch.save(features_subset, features_subset_file)
file_extension = "npy"
else:
file_extension = "pt"
else:
features_subset_file = os.path.join(
output_dir,
f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy",
)
features_subset = np.vstack(features[module_name])
np.save(features_subset_file, features_subset)
features = defaultdict(list)
last_image_ct = image_ct
file_extension = "npy"

features_subset_file = os.path.join(
curr_output_dir,
f"features{file_name_suffix}_{last_image_ct}-{image_ct}.{file_extension}",
)
self._save_features(
features_subset, features_subset_file, file_extension
)

# Note: we add full file paths to feature_file_names to be able to load the features later
feature_file_names[module_name].append(features_subset_file)
features[module_name] = []
last_image_ct = image_ct

print(
f"...Features successfully extracted for all {image_ct} images in the database."
)
Expand All @@ -316,29 +328,31 @@ def extract_features(
features = []
for file in feature_file_names[module_name]:
if self.get_backend() == "pt" and output_type != "ndarray":
if file.endswith(".pt"):
features.append(
torch.load(os.path.join(output_dir, file))
)
features.append(torch.load(file))
elif file.endswith(".npy"):
features.append(np.load(file))
else:
if file.endswith(".npy"):
features.append(
np.load(os.path.join(output_dir, file))
)
raise ValueError(
f"Invalid or unsupported file extension: {file}"
)

features_file = os.path.join(
output_dir, f"{module_name}/features{file_name_suffix}"
)
if output_type == "ndarray":
np.save(f"{features_file}.npy", np.concatenate(features))
else: # output_type = tensor
torch.save(torch.cat(features), f"{features_file}.pt")
self._save_features(
np.concatenate(features), features_file + ".npy", "npy"
)
else:
self._save_features(
torch.cat(features), features_file + ".pt", "pt"
)
print(
f"...Features for module '{module_name}' were saved to {features_file}."
)
# remove temporary files
for file in feature_file_names[module_name]:
os.remove(os.path.join(output_dir, file))
os.remove(file)

print(f"...Features were saved to {output_dir}.")
return None
else:
Expand All @@ -349,9 +363,11 @@ def extract_features(
features[module_name] = self._to_numpy(features[module_name])
else:
features[module_name] = np.vstack(features[module_name])
print(f"...Features shape: {features[module_name].shape}")
print(
f"...Features for module '{module_name}' have shape: {features[module_name].shape}"
)

if single_module_call:
# for backward compatibility
return features[module_name]
return features

Expand Down
Loading
Loading