Skip to content

Commit

Permalink
Hotfix/sg 000 add classnames to config (#191)
Browse files Browse the repository at this point in the history
* wip

* add classnames and move data_config to be instantiated earlier
  • Loading branch information
Louis-Dupont authored Sep 26, 2023
1 parent 89f8bde commit 4449ded
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 306 deletions.
22 changes: 1 addition & 21 deletions src/data_gradients/dataset_adapters/base_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import List, Tuple
from typing import Tuple

import torch

Expand All @@ -24,31 +24,11 @@ def __init__(
dataset_output_mapper: DatasetOutputMapper,
formatter: BatchFormatter,
data_config: DataConfig,
class_names: List[str],
):
self.data_config = data_config

self.dataset_output_mapper = dataset_output_mapper
self.formatter = formatter
self.class_names = class_names

@staticmethod
def resolve_class_names(class_names: List[str], n_classes: int) -> List[str]:
"""Ensure that either `class_names` or `n_classes` is specified, but not both. Return the list of class names that will be used."""
if n_classes and class_names:
raise RuntimeError("`class_names` and `n_classes` cannot be specified at the same time")
elif n_classes is None and class_names is None:
raise RuntimeError("Either `class_names` or `n_classes` must be specified")
return class_names or list(map(str, range(n_classes)))

@staticmethod
def resolve_class_names_to_use(class_names: List[str], class_names_to_use: List[str]) -> List[str]:
"""Define `class_names_to_use` from `class_names` if it is specified. Otherwise, return the list of class names that will be used."""
if class_names_to_use:
invalid_class_names_to_use = set(class_names_to_use) - set(class_names)
if invalid_class_names_to_use != set():
raise RuntimeError(f"You defined `class_names_to_use` with classes that are not listed in `class_names`: {invalid_class_names_to_use}")
return class_names_to_use or class_names

def adapt(self, data: SupportedDataType) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapt an input data (Batch or Sample) into a standard format.
Expand Down
52 changes: 7 additions & 45 deletions src/data_gradients/dataset_adapters/classification_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from typing import List, Optional, Callable
import torch

from data_gradients.dataset_adapters.config.typing import SupportedDataType
from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig
Expand All @@ -11,49 +7,15 @@
class ClassificationDatasetAdapter(BaseDatasetAdapter):
"""Wrap a classification dataset so that it would return standardized tensors.
:param cache_path: The filename of the cache file.
:param n_classes: The number of classes.
:param class_names: List of class names.
:param class_names_to_use: List of class names to use.
:param images_extractor: Callable function for extracting images.
:param labels_extractor: Callable function for extracting labels.
:param n_image_channels: Number of image channels.
:param data_config: Instance of ClassificationDataConfig class that manages dataset/dataloader configurations.
"""

def __init__(
self,
cache_path: Optional[str] = None,
n_classes: Optional[int] = None,
class_names: Optional[List[str]] = None,
class_names_to_use: Optional[List[str]] = None,
images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
is_batch: Optional[bool] = None,
n_image_channels: int = 3,
data_config: Optional[ClassificationDataConfig] = None,
):
class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

if data_config is None:
data_config = ClassificationDataConfig(
cache_path=cache_path,
images_extractor=images_extractor,
labels_extractor=labels_extractor,
is_batch=is_batch,
)

def __init__(self, data_config: ClassificationDataConfig, n_image_channels: int = 3):
dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
formatter = ClassificationBatchFormatter(
data_config=data_config,
class_names=class_names,
class_names_to_use=class_names_to_use,
n_image_channels=n_image_channels,
)
super().__init__(
dataset_output_mapper=dataset_output_mapper,
formatter=formatter,
data_config=data_config,
class_names=class_names,
)
formatter = ClassificationBatchFormatter(data_config=data_config, n_image_channels=n_image_channels)
super().__init__(dataset_output_mapper=dataset_output_mapper, formatter=formatter, data_config=data_config)

@classmethod
def from_cache(cls, cache_path: str) -> "ClassificationDatasetAdapter":
return cls(data_config=ClassificationDataConfig(cache_path=cache_path))
38 changes: 37 additions & 1 deletion src/data_gradients/dataset_adapters/config/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Optional, Callable, Union
from typing import Dict, Optional, Callable, Union, List

import data_gradients
from data_gradients.dataset_adapters.config.questions import Question, ask_question, text_to_yellow
Expand Down Expand Up @@ -36,6 +36,10 @@ class DataConfig(ABC):
labels_extractor: Union[None, str, Callable[[SupportedDataType], torch.Tensor]] = None
is_batch: Union[None, bool] = None

n_classes: Union[None, int] = None
class_names: Union[None, List[str]] = None
class_names_to_use: Union[None, List[str]] = None

cache_path: Optional[str] = None

def __post_init__(self):
Expand All @@ -45,6 +49,11 @@ def __post_init__(self):
else:
logger.info(f"Cache deactivated for `{self.__class__.__name__}`.")

# Resolve class related params. This should be done after updating from the cache file because the cache may include some of these parameters.
self.class_names = resolve_class_names(class_names=self.class_names, n_classes=self.n_classes)
self.n_classes = len(self.class_names)
self.class_names_to_use = resolve_class_names_to_use(class_names=self.class_names, class_names_to_use=self.class_names_to_use)

def update_from_cache_file(self):
"""Update the values that are not set yet, using the cache file."""
if self.cache_path is not None and os.path.isfile(self.cache_path):
Expand Down Expand Up @@ -106,6 +115,9 @@ def to_json(self) -> JSONDict:
"images_extractor": TensorExtractorResolver.to_string(self.images_extractor),
"labels_extractor": TensorExtractorResolver.to_string(self.labels_extractor),
"is_batch": self.is_batch,
"n_classes": self.n_classes,
"class_names": self.class_names,
"class_names_to_use": self.class_names_to_use,
}
return json_dict

Expand Down Expand Up @@ -135,6 +147,12 @@ def _fill_missing_params(self, json_dict: JSONDict):
self.labels_extractor = json_dict.get("labels_extractor")
if self.is_batch is None:
self.is_batch = json_dict.get("is_batch")
# Class-related parameters: a value already set on the instance takes precedence,
# and the cached value is only used to fill in a field that is still None.
if self.n_classes is None:
    self.n_classes = json_dict.get("n_classes")
if self.class_names is None:
    self.class_names = json_dict.get("class_names")
# The condition must be `is None`, like the fields above. With `is not None` the cached value
# overwrote a user-provided `class_names_to_use` and was never loaded when the field was missing.
if self.class_names_to_use is None:
    self.class_names_to_use = json_dict.get("class_names_to_use")

def get_images_extractor(self, question: Optional[Question] = None, hint: str = "") -> Callable[[SupportedDataType], torch.Tensor]:
if self.images_extractor is None:
Expand Down Expand Up @@ -209,3 +227,21 @@ def get_xyxy_converter(self, hint: str = "") -> Callable[[torch.Tensor], torch.T
)
self.xyxy_converter = ask_question(question=question, hint=hint)
return XYXYConverterResolver.to_callable(self.xyxy_converter)


def resolve_class_names(class_names: List[str], n_classes: int) -> List[str]:
    """Return the list of class names to use, derived from `class_names` or `n_classes`.

    Both values may be given as long as they agree; this is what a config holds once `n_classes`
    and `class_names` have both been filled in (e.g. after being restored from a cache that stores both).

    :param class_names: Explicit class names. `None` or an empty list is treated as "not specified".
    :param n_classes:   Number of classes. Used to generate names only when `class_names` is not specified.
    :return:            `class_names` when specified, otherwise `["0", "1", ..., str(n_classes - 1)]`.
    :raises RuntimeError: If both are specified but disagree, or if neither can be used.
    """
    if class_names:
        # `n_classes` is redundant here; only reject it when it contradicts the explicit names.
        if n_classes and n_classes != len(class_names):
            raise RuntimeError(f"`n_classes` ({n_classes}) does not match the number of `class_names` ({len(class_names)})")
        return class_names

    # No usable class names: fall back to n_classes. Checking it explicitly avoids the
    # TypeError that `range(None)` raised when `class_names` was an empty list.
    if n_classes is None:
        raise RuntimeError("Either `class_names` or `n_classes` must be specified")
    return list(map(str, range(n_classes)))


def resolve_class_names_to_use(class_names: List[str], class_names_to_use: List[str]) -> List[str]:
    """Return the classes to analyze, defaulting to every class in `class_names`.

    :param class_names:         All class names of the dataset.
    :param class_names_to_use:  Requested subset. `None` or an empty list means "use all of `class_names`".
    :return:                    `class_names_to_use` when it is non-empty, otherwise `class_names`.
    :raises RuntimeError:       If `class_names_to_use` contains a name that is not listed in `class_names`.
    """
    # Nothing requested: every class is used.
    if not class_names_to_use:
        return class_names

    unknown_names = set(class_names_to_use) - set(class_names)
    if unknown_names:
        raise RuntimeError(f"You defined `class_names_to_use` with classes that are not listed in `class_names`: {unknown_names}")
    return class_names_to_use
60 changes: 8 additions & 52 deletions src/data_gradients/dataset_adapters/detection_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import List, Optional, Callable
from typing import Optional

import torch

from data_gradients.dataset_adapters.config.typing import SupportedDataType
from data_gradients.dataset_adapters.base_adapter import BaseDatasetAdapter
from data_gradients.dataset_adapters.output_mapper.dataset_output_mapper import DatasetOutputMapper
from data_gradients.dataset_adapters.formatters.detection import DetectionBatchFormatter
Expand All @@ -11,56 +8,15 @@

class DetectionDatasetAdapter(BaseDatasetAdapter):
"""Wrap a detection dataset so that it would return standardized tensors.
:param cache_path: The filename of the cache file.
:param n_classes: The number of classes.
:param class_names: List of class names.
:param class_names_to_use: List of class names to use.
:param images_extractor: Callable function for extracting images.
:param labels_extractor: Callable function for extracting labels.
:param is_label_first: A flag to indicate if labels are the first entity in the dataset.
:param bbox_format: Callable function for formatting bounding boxes.
:param n_image_channels: Number of image channels.
:param data_config: Instance of DetectionDataConfig class that manages dataset/dataloader configurations.
"""

def __init__(
self,
cache_path: Optional[str] = None,
n_classes: Optional[int] = None,
class_names: Optional[List[str]] = None,
class_names_to_use: Optional[List[str]] = None,
images_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
labels_extractor: Optional[Callable[[SupportedDataType], torch.Tensor]] = None,
is_batch: Optional[bool] = None,
is_label_first: Optional[bool] = None,
bbox_format: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
n_image_channels: int = 3,
data_config: Optional[DetectionDataConfig] = None,
):
class_names = self.resolve_class_names(class_names=class_names, n_classes=n_classes)
class_names_to_use = self.resolve_class_names_to_use(class_names=class_names, class_names_to_use=class_names_to_use)

if data_config is None:
data_config = DetectionDataConfig(
cache_path=cache_path,
images_extractor=images_extractor,
labels_extractor=labels_extractor,
is_batch=is_batch,
is_label_first=is_label_first,
xyxy_converter=bbox_format,
)

def __init__(self, data_config: Optional[DetectionDataConfig] = None, n_image_channels: int = 3):
dataset_output_mapper = DatasetOutputMapper(data_config=data_config)
formatter = DetectionBatchFormatter(
data_config=data_config,
class_names=class_names,
class_names_to_use=class_names_to_use,
n_image_channels=n_image_channels,
)
super().__init__(
dataset_output_mapper=dataset_output_mapper,
formatter=formatter,
data_config=data_config,
class_names=class_names,
)
formatter = DetectionBatchFormatter(data_config=data_config, n_image_channels=n_image_channels)
super().__init__(dataset_output_mapper=dataset_output_mapper, formatter=formatter, data_config=data_config)

@classmethod
def from_cache(cls, cache_path: str) -> "DetectionDatasetAdapter":
return cls(data_config=DetectionDataConfig(cache_path=cache_path))
16 changes: 7 additions & 9 deletions src/data_gradients/dataset_adapters/formatters/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Tuple, List
from typing import Tuple

import torch
from torch import Tensor
Expand All @@ -8,6 +8,9 @@
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape
from data_gradients.dataset_adapters.formatters.utils import ensure_channel_first
from data_gradients.dataset_adapters.config.data_config import ClassificationDataConfig
from logging import getLogger

logger = getLogger(__name__)


class UnsupportedClassificationBatchFormatError(DatasetFormatError):
Expand All @@ -21,22 +24,17 @@ class ClassificationBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: ClassificationDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
"""
self.data_config = data_config

class_names_to_use = set(class_names_to_use)
self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

self.n_image_channels = n_image_channels

# A `class_names_to_use` that differs from `class_names` means a subset was requested;
# warn that it has no effect for classification.
if data_config.class_names_to_use != data_config.class_names:
    logger.warning("Classification task does NOT support class filtering, yet `class_names_to_use` was set. This parameter will be ignored.")

def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
"""Validate batch images and labels format, and ensure that they are in the relevant format for classification.
Expand Down
8 changes: 1 addition & 7 deletions src/data_gradients/dataset_adapters/formatters/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,18 @@ class DetectionBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: DetectionDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
xyxy_converter: Optional[Callable[[Tensor], Tensor]] = None,
label_first: Optional[bool] = None,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
:param xyxy_converter: Function to convert the bboxes to the `xyxy` format.
:param label_first: Whether the annotated_bboxes starts with labels, or with the bboxes. (typically label_xyxy vs xyxy_label)
"""
self.data_config = data_config

class_names_to_use = set(class_names_to_use)
self.class_ids_to_use = [class_id for class_id, class_name in enumerate(class_names) if class_name in class_names_to_use]

self.class_ids_to_use = [data_config.class_names.index(class_name) for class_name in data_config.class_names_to_use]
self.n_image_channels = n_image_channels
self.xyxy_converter = xyxy_converter
self.label_first = label_first
Expand Down
18 changes: 6 additions & 12 deletions src/data_gradients/dataset_adapters/formatters/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,17 @@ class SegmentationBatchFormatter(BatchFormatter):
def __init__(
self,
data_config: SegmentationDataConfig,
class_names: List[str],
class_names_to_use: List[str],
n_image_channels: int,
threshold_value: float,
ignore_labels: Optional[List[int]] = None,
):
"""
:param class_names: List of all class names in the dataset. The index should represent the class_id.
:param class_names_to_use: List of class names that we should use for analysis.
:param n_image_channels: Number of image channels (3 for RGB, 1 for Gray Scale, ...)
:param threshold_value: Threshold
:param ignore_labels: Numbers that we should avoid from analyzing as valid classes, such as background
"""
class_names_to_use = set(class_names_to_use)

self.class_names = class_names
self.class_ids_to_ignore = [class_id for class_id, class_name in enumerate(class_names) if class_name not in class_names_to_use]
classes_to_ignore = set(data_config.class_names) - set(data_config.class_names_to_use)
self.class_ids_to_ignore = [data_config.class_names.index(class_name_to_ignore) for class_name_to_ignore in classes_to_ignore]

self.n_image_channels = n_image_channels
self.ignore_labels = ignore_labels or []
Expand Down Expand Up @@ -63,12 +57,12 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
labels = ensure_channel_first(labels, n_image_channels=self.n_image_channels)

images = check_images_shape(images, n_image_channels=self.n_image_channels)
labels = self.validate_labels_dim(labels, n_classes=len(self.class_names), ignore_labels=self.ignore_labels)

labels = self.ensure_hard_labels(labels, n_classes=len(self.class_names), threshold_value=self.threshold_value)
labels = self.validate_labels_dim(labels, n_classes=self.data_config.n_classes, ignore_labels=self.ignore_labels)
labels = self.ensure_hard_labels(labels, n_classes=self.data_config.n_classes, threshold_value=self.threshold_value)

if self.require_onehot(labels=labels, n_classes=len(self.class_names)):
labels = to_one_hot(labels, n_classes=len(self.class_names))
if self.require_onehot(labels=labels, n_classes=self.data_config.n_classes):
labels = to_one_hot(labels, n_classes=self.data_config.n_classes)

for class_id_to_ignore in self.class_ids_to_ignore:
labels[:, class_id_to_ignore, ...] = 0
Expand Down
Loading

0 comments on commit 4449ded

Please sign in to comment.