Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/sg 000 add classnames to config #191

Merged
merged 3 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions src/data_gradients/dataset_adapters/base_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import List, Tuple
from typing import Tuple

import torch

Expand All @@ -24,31 +24,11 @@ def __init__(
dataset_output_mapper: DatasetOutputMapper,
formatter: BatchFormatter,
data_config: DataConfig,
class_names: List[str],
):
self.data_config = data_config

self.dataset_output_mapper = dataset_output_mapper
self.formatter = formatter
self.class_names = class_names

@staticmethod
def resolve_class_names(class_names: List[str], n_classes: int) -> List[str]:
"""Ensure that either `class_names` or `n_classes` is specified, but not both. Return the list of class names that will be used."""
if n_classes and class_names:
raise RuntimeError("`class_names` and `n_classes` cannot be specified at the same time")
elif n_classes is None and class_names is None:
raise RuntimeError("Either `class_names` or `n_classes` must be specified")
return class_names or list(map(str, range(n_classes)))

@staticmethod
def resolve_class_names_to_use(class_names: List[str], class_names_to_use: List[str]) -> List[str]:
"""Define `class_names_to_use` from `class_names` if it is specified. Otherwise, return the list of class names that will be used."""
if class_names_to_use:
invalid_class_names_to_use = set(class_names_to_use) - set(class_names)
if invalid_class_names_to_use != set():
raise RuntimeError(f"You defined `class_names_to_use` with classes that are not listed in `class_names`: {invalid_class_names_to_use}")
return class_names_to_use or class_names

def adapt(self, data: SupportedDataType) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapt an input data (Batch or Sample) into a standard format.
Expand Down
52 changes: 7 additions & 45 deletions src/data_gradients/dataset_adapters/classification_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from typing import List, Optional, Callable
import torch

from data_gradients.dataset_adapters.config.typing import SupportedDataType
from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig
Expand All @@ -11,49 +7,15 @@
class ClassificationDatasetAdapter(BaseDatasetAdapter):
"""Wrap a classification dataset so that it would return standardized tensors.

:param cache_path: The filename of the cache file.
:param n_classes: The number of classes.
:param class_names: List of class names.
:param class_names_to_use: List of class names to use.
:param images_extractor: Callable function for extracting images.
:param labels_extractor: Callable function for extracting labels.
:param n_image_channels: Number of image channels.
:param data_config: Instance of DetectionDataConfig class that manages dataset/dataloader configurations.
"""

def __init__(
self,
cache_path: Optional[str] = None,
n_classes: Optional[int] = None,
class_names: Optional[List[str]] = None,
class_names_to_use: Optional[List[str]] = None,
images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
is_batch: Optional[bool] = None,
n_image_channels: int = 3,
data_config: Optional[ClassificationDataConfig] = None,
):
class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

if data_config is None:
data_config = ClassificationDataConfig(
cache_path=cache_path,
images_extractor=images_extractor,
labels_extractor=labels_extractor,
is_batch=is_batch,
)

def __init__(self, data_config: ClassificationDataConfig, n_image_channels: int = 3):
dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
formatter = ClassificationBatchFormatter(
data_config=data_config,
class_names=class_names,
class_names_to_use=class_names_to_use,
n_image_channels=n_image_channels,
)
super().__init__(
dataset_output_mapper=dataset_output_mapper,
formatter=formatter,
data_config=data_config,
class_names=class_names,
)
formatter = ClassificationBatchFormatter(data_config=data_config, n_image_channels=n_image_channels)
super().__init__(dataset_output_mapper=dataset_output_mapper, formatter=formatter, data_config=data_config)

@classmethod
def from_cache(cls, cache_path: str) -> "ClassificationDatasetAdapter":
return cls(data_config=ClassificationDataConfig(cache_path=cache_path))
38 changes: 37 additions & 1 deletion src/data_gradients/dataset_adapters/config/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Optional, Callable, Union
from typing import Dict, Optional, Callable, Union, List

import data_gradients
from data_gradients.dataset_adapters.config.questions import Question, ask_question, text_to_yellow
Expand Down Expand Up @@ -36,6 +36,10 @@ class DataConfig(ABC):
labels_extractor: Union[None, str, Callable[[SupportedDataType], torch.Tensor]] = None
is_batch: Union[None, bool] = None

n_classes: Union[None, int] = None
class_names: Union[None, List[str]] = None
class_names_to_use: Union[None, List[str]] = None

cache_path: Optional[str] = None

def __post_init__(self):
Expand All @@ -45,6 +49,11 @@ def __post_init__(self):
else:
logger.info(f"Cache deactivated for `{self.__class__.__name__}`.")

# Resolve class related params. This should be done after updating from the cache file because the cache may include some of these parameters.
self.class_names = resolve_class_names(class_names=self.class_names, n_classes=self.n_classes)
self.n_classes = len(self.class_names)
self.class_names_to_use = resolve_class_names_to_use(class_names=self.class_names, class_names_to_use=self.class_names_to_use)

def update_from_cache_file(self):
"""Update the values that are not set yet, using the cache file."""
if self.cache_path is not None and os.path.isfile(self.cache_path):
Expand Down Expand Up @@ -106,6 +115,9 @@ def to_json(self) -> JSONDict:
"images_extractor": TensorExtractorResolver.to_string(self.images_extractor),
"labels_extractor": TensorExtractorResolver.to_string(self.labels_extractor),
"is_batch": self.is_batch,
"n_classes": self.n_classes,
"class_names": self.class_names,
"class_names_to_use": self.class_names_to_use,
}
return json_dict

Expand Down Expand Up @@ -135,6 +147,12 @@ def _fill_missing_params(self, json_dict: JSONDict):
self.labels_extractor = json_dict.get("labels_extractor")
if self.is_batch is None:
self.is_batch = json_dict.get("is_batch")
if self.n_classes is None:
self.n_classes = json_dict.get("n_classes")
if self.class_names is None:
self.class_names = json_dict.get("class_names")
if self.class_names_to_use is not None:
self.class_names_to_use = json_dict.get("class_names_to_use")

def get_images_extractor(self, question: Optional[Question] = None, hint: str = "") -> Callable[[SupportedDataType], torch.Tensor]:
if self.images_extractor is None:
Expand Down Expand Up @@ -209,3 +227,21 @@ def get_xyxy_converter(self, hint: str = "") -> Callable[[torch.Tensor], torch.T
)
self.xyxy_converter = ask_question(question=question, hint=hint)
return XYXYConverterResolver.to_callable(self.xyxy_converter)


def resolve_class_names(class_names: List[str], n_classes: int) -> List[str]:
"""Ensure that either `class_names` or `n_classes` is specified, but not both. Return the list of class names that will be used."""
if n_classes and class_names:
raise RuntimeError("`class_names` and `n_classes` cannot be specified at the same time")
elif n_classes is None and class_names is None:
raise RuntimeError("Either `class_names` or `n_classes` must be specified")
return class_names or list(map(str, range(n_classes)))


def resolve_class_names_to_use(class_names: List[str], class_names_to_use: List[str]) -> List[str]:
"""Define `class_names_to_use` from `class_names` if it is specified. Otherwise, return the list of class names that will be used."""
if class_names_to_use:
invalid_class_names_to_use = set(class_names_to_use) - set(class_names)
if invalid_class_names_to_use != set():
raise RuntimeError(f"You defined `class_names_to_use` with classes that are not listed in `class_names`: {invalid_class_names_to_use}")
return class_names_to_use or class_names
60 changes: 8 additions & 52 deletions src/data_gradients/dataset_adapters/detection_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import List, Optional, Callable
from typing import Optional

import torch

from data_gradients.dataset_adapters.config.typing import SupportedDataType
from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
from data_gradients.dataset_adapters.formatters.detection import DetectionBatchFormatter
Expand All @@ -11,56 +8,15 @@

class DetectionDatasetAdapter(BaseDatasetAdapter):
"""Wrap a detection dataset so that it would return standardized tensors.

:param cache_path: The filename of the cache file.
:param n_classes: The number of classes.
:param class_names: List of class names.
:param class_names_to_use: List of class names to use.
:param images_extractor: Callable function for extracting images.
:param labels_extractor: Callable function for extracting labels.
:param is_label_first: A flag to indicate if labels are the first entity in the dataset.
:param bbox_format: Callable function for formatting bounding boxes.
:param n_image_channels: Number of image channels.
:param data_config: Instance of DetectionDataConfig class that manages dataset/dataloader configurations.
"""

def __init__(
self,
cache_path: Optional[str] = None,
n_classes: Optional[int] = None,
class_names: Optional[List[str]] = None,
class_names_to_use: Optional[List[str]] = None,
images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
is_batch: Optional[bool] = None,
is_label_first: Optional[bool] = None,
bbox_format: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
n_image_channels: int = 3,
data_config: Optional[DetectionDataConfig] = None,
):
class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

if data_config is None:
data_config = DetectionDataConfig(
cache_path=cache_path,
images_extractor=images_extractor,
labels_extractor=labels_extractor,
is_batch=is_batch,
is_label_first=is_label_first,
xyxy_converter=bbox_format,
)

def __init__(self, data_config: Optional[DetectionDataConfig] = None, n_image_channels: int = 3):
dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
formatter = DetectionBatchFormatter(
data_config=data_config,
class_names=class_names,
class_names_to_use=class_names_to_use,
n_image_channels=n_image_channels,
)
super().__init__(
dataset_output_mapper=dataset_output_mapper,
formatter=formatter,
data_config=data_config,
class_names=class_names,
)
formatter = DetectionBatchFormatter(data_config=data_config, n_image_channels=n_image_channels)
super().__init__(dataset_output_mapper=dataset_output_mapper, formatter=formatter, data_config=data_config)

@classmethod
def from_cache(cls, cache_path: str) -> "DetectionDatasetAdapter":
return cls(data_config=DetectionDataConfig(cache_path=cache_path))
16 changes: 7 additions & 9 deletions src/data_gradients/dataset_adapters/formatters/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Tuple, List
from typing import Tuple

import torch
from torch import Tensor
Expand All @@ -8,6 +8,9 @@
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape
from data_gradients.dataset_adapters.formatters.utils import ensure_channel_first
from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig
from logging import getLogger

logger = getLogger(__name__)


class UnsupportedClassificationBatchFormatError(DatasetFormatError):
Expand All @@ -21,22 +24,17 @@ class ClassificationBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: ClassificationDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
"""
self.data_config = data_config

class_names_to_use = set(class_names_to_use)
self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

self.n_image_channels = n_image_channels

if data_config.class_names_to_use != data_config.class_names:
logger.warning("Classification task does NOT support class filtering, yet `class_names_to_use` was set. This will parameter will be ignored.")

def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
"""Validate batch images and labels format, and ensure that they are in the relevant format for detection.

Expand Down
8 changes: 1 addition & 7 deletions src/data_gradients/dataset_adapters/formatters/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,18 @@ class DetectionBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: DetectionDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
xyxy_converter: Optional[Callable[[Tensor], Tensor]] = None,
label_first: Optional[bool] = None,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
:param xyxy_converter: Function to convert the bboxes to the `xyxy` format.
:param label_first: Whether the annotated_bboxes states with labels, or with the bboxes. (typically label_xyxy vs xyxy_label)
"""
self.data_config = data_config

class_names_to_use = set(class_names_to_use)
self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

self.class_ids_to_use = [data_config.class_names.index(class_name) for class_name in data_config.class_names_to_use]
self.n_image_channels = n_image_channels
self.xyxy_converter = xyxy_converter
self.label_first = label_first
Expand Down
18 changes: 6 additions & 12 deletions src/data_gradients/dataset_adapters/formatters/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,17 @@ class SegmentationBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: SegmentationDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
threshold_value: float,
ignore_labels: Optional[List[int]] = None,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
:param threshold_value: Threshold
:param ignore_labels: Numbers that we should avoid from analyzing as valid classes, such as background
"""
class_names_to_use = set(class_names_to_use)

self.class_names = class_names
self.class_ids_to_ignore = [class_id for class_id, class_name in enumerate(class_names) if class_name not in class_names_to_use]
classes_to_ignore = set(data_config.class_names) - set(data_config.class_names_to_use)
self.class_ids_to_ignore = [data_config.class_names.index(class_name_to_ignore) for class_name_to_ignore in classes_to_ignore]

self.n_image_channels = n_image_channels
self.ignore_labels = ignore_labels or []
Expand Down Expand Up @@ -63,12 +57,12 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
labels = ensure_channel_first(labels, n_image_channels=self.n_image_channels)

images = check_images_shape(images, n_image_channels=self.n_image_channels)
labels = self.validate_labels_dim(labels, n_classes=len(self.class_names), ignore_labels=self.ignore_labels)

labels = self.ensure_hard_labels(labels, n_classes=len(self.class_names), threshold_value=self.threshold_value)
labels = self.validate_labels_dim(labels, n_classes=self.data_config.n_classes, ignore_labels=self.ignore_labels)
labels = self.ensure_hard_labels(labels, n_classes=self.data_config.n_classes, threshold_value=self.threshold_value)

if self.require_onehot(labels=labels, n_classes=len(self.class_names)):
labels = to_one_hot(labels, n_classes=len(self.class_names))
if self.require_onehot(labels=labels, n_classes=self.data_config.n_classes):
labels = to_one_hot(labels, n_classes=self.data_config.n_classes)

for class_id_to_ignore in self.class_ids_to_ignore:
labels[:, class_id_to_ignore, ...] = 0
Expand Down
Loading