Commit d755e2b

Merge commit by Gabriele Graffieti (2 parents: 6f3f10a + ce4ca87).

File tree: 5 files changed, +396 −152 lines.

avalanche/benchmarks/utils/avalanche_dataset.py (144 additions, 64 deletions)
```diff
@@ -18,19 +18,19 @@
 """
 import copy
 import warnings
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict, deque
 from enum import Enum, auto

 import torch
 from torch.utils.data.dataloader import default_collate
-from torch.utils.data.dataset import Dataset, Subset
+from torch.utils.data.dataset import Dataset, Subset, ConcatDataset
 from torchvision.transforms import Compose

 from .dataset_utils import manage_advanced_indexing, \
-    SequenceDataset, ClassificationSubset, LazyTargetsConversion, \
+    SequenceDataset, ClassificationSubset, \
     LazyConcatIntTargets, find_list_from_index, ConstantSequence, \
     LazyClassMapping, optimize_sequence, SubSequence, LazyConcatTargets, \
-    SubsetWithTargets
+    TupleTLabel
 from .dataset_definitions import ITensorDataset, ClassificationDataset, \
     IDatasetWithTargets, ISupportedClassificationDataset
```
```diff
@@ -49,7 +49,8 @@
 XTransform = Optional[Callable[[Any], Any]]
 YTransform = Optional[Callable[[Any], TTargetType]]

-SupportedDataset = Union[IDatasetWithTargets, ITensorDataset, Subset]
+SupportedDataset = Union[IDatasetWithTargets, ITensorDataset, Subset,
+                         ConcatDataset]


 class AvalancheDatasetType(Enum):
```
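For context: with `ConcatDataset` added to the `SupportedDataset` union above, a plain PyTorch concatenation can be wrapped directly. A minimal sketch, assuming `AvalancheDataset` is importable from `avalanche.benchmarks.utils` as in this version of the library:

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset
from avalanche.benchmarks.utils import AvalancheDataset

# Two plain PyTorch datasets, concatenated the PyTorch way...
part_a = TensorDataset(torch.rand(10, 3), torch.randint(0, 5, (10,)))
part_b = TensorDataset(torch.rand(6, 3), torch.randint(0, 5, (6,)))

# ...can now be handed directly to AvalancheDataset; targets are
# recovered from both parts by the traversal helpers added below.
whole = AvalancheDataset(ConcatDataset([part_a, part_b]))
print(len(whole), len(whole.targets))  # 16 16
```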
```diff
@@ -302,8 +303,8 @@ def __radd__(self, other: Dataset) -> 'AvalancheDataset':
         return AvalancheConcatDataset([other, self])

     def __getitem__(self, idx) -> Union[T_co, Sequence[T_co]]:
-        return manage_advanced_indexing(
-            idx, self._get_single_item, len(self), self.collate_fn)
+        return TupleTLabel(manage_advanced_indexing(
+            idx, self._get_single_item, len(self), self.collate_fn))

     def __len__(self):
         return len(self._dataset)
```
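For context: `manage_advanced_indexing` lets an `AvalancheDataset` be indexed with lists or slices and returns a collated result; the change above re-wraps that result in `TupleTLabel` so callers can still tell that the last element holds task labels. A sketch of the assumed behavior (the collated task-label tensor is an assumption based on `default_collate`):

```python
import torch
from torch.utils.data import TensorDataset
from avalanche.benchmarks.utils import AvalancheDataset

data = AvalancheDataset(TensorDataset(torch.rand(4, 2), torch.arange(4)))
x, y, t = data[[0, 2]]  # advanced (list) indexing returns a collated batch
print(t)                # tensor([0, 0]): default task labels, still attached
```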
```diff
@@ -598,11 +599,10 @@ def _freeze_dataset_group(dataset_copy: TAvalancheDataset,
         dataset_copy._freeze_original_dataset(group_name)

     def _get_single_item(self, idx: int):
-        return self._process_pattern(
-            self._dataset[idx], idx,
-            isinstance(self._dataset, AvalancheDataset))
+        return self._process_pattern(self._dataset[idx], idx)

-    def _process_pattern(self, element: Tuple, idx: int, has_task_label: bool):
+    def _process_pattern(self, element: Tuple, idx: int):
+        has_task_label = isinstance(element, TupleTLabel)
         if has_task_label:
             element = element[:-1]

```
```diff
@@ -611,7 +611,8 @@ def _process_pattern(self, element: Tuple, idx: int, has_task_label: bool):

         pattern, label = self._apply_transforms(pattern, label)

-        return (pattern, label, *element[2:], self.targets_task_labels[idx])
+        return TupleTLabel((pattern, label, *element[2:],
+                            self.targets_task_labels[idx]))

     def _apply_transforms(self, pattern: Any, label: int):
         frozen_group = self._frozen_transforms[self.current_transform_group]
```
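The `has_task_label` flag no longer travels as a parameter: the element itself carries the information. A sketch of the marker-type pattern, assuming `TupleTLabel` is a bare `tuple` subclass defined in `dataset_utils`:

```python
class TupleTLabel(tuple):
    """Marker type: a tuple whose last element is a task label."""


element = TupleTLabel(('<pattern>', 7, 0))  # (x, y, task_label)
if isinstance(element, TupleTLabel):        # replaces the old boolean flag
    element = element[:-1]                  # slicing yields a plain tuple
print(element)                              # ('<pattern>', 7)
```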
```diff
@@ -710,6 +711,11 @@ def _initialize_targets_sequence(self, dataset, targets,
         if targets is not None:
             # User defined targets always take precedence
             # Note: no adapter is applied!
+            if len(targets) != len(dataset):
+                raise ValueError(
+                    'Invalid amount of target labels. It must be equal to the '
+                    'number of patterns in the dataset. Got {}, expected '
+                    '{}!'.format(len(targets), len(dataset)))
             return targets

         if targets_adapter is None:
```
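The new check turns a silent misalignment into an early failure. A hypothetical misuse it catches, assuming the constructor exposes the `targets` parameter that feeds `_initialize_targets_sequence`:

```python
import torch
from torch.utils.data import TensorDataset
from avalanche.benchmarks.utils import AvalancheDataset

data = TensorDataset(torch.rand(5, 2), torch.zeros(5, dtype=torch.long))
AvalancheDataset(data, targets=[0, 1, 2])  # ValueError: got 3, expected 5
```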
```diff
@@ -731,13 +737,7 @@ def _initialize_task_labels_sequence(
                     '{}!'.format(len(task_labels), len(dataset)))
             return task_labels

-        if hasattr(dataset, 'targets_task_labels'):
-            # Dataset is probably a dataset of this class
-            # Suppose that it is
-            return LazyTargetsConversion(dataset.targets_task_labels, int)
-
-        # No task labels found. Set all task labels to 0 (in a lazy way).
-        return ConstantSequence(0, len(dataset))
+        return _make_task_labels_from_supported_dataset(dataset)

     def _initialize_collate_fn(self, dataset, dataset_type, collate_fn):
         if collate_fn is not None:
```
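The fallback logic now lives in `_make_task_labels_from_supported_dataset`, added at the bottom of this file. The assumed net behavior for plain datasets is unchanged: when no task labels are found, a lazy constant sequence of zeros is used:

```python
import torch
from torch.utils.data import TensorDataset
from avalanche.benchmarks.utils import AvalancheDataset

plain = TensorDataset(torch.rand(4, 2), torch.zeros(4, dtype=torch.long))
print(list(AvalancheDataset(plain).targets_task_labels))  # [0, 0, 0, 0]
```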
```diff
@@ -952,11 +952,14 @@ def __init__(self,
             raise ValueError('class_mapping is defined but the dataset type'
                              ' is neither CLASSIFICATION or UNDEFINED.')

-        if class_mapping:
+        if class_mapping is not None:
             subset = ClassificationSubset(dataset, indices=indices,
                                           class_mapping=class_mapping)
+        elif indices is not None:
+            subset = Subset(dataset, indices=indices)
         else:
-            subset = SubsetWithTargets(dataset, indices=indices)
+            subset = dataset  # Exactly like a plain AvalancheDataset
+
         self._original_dataset = dataset
         self._indices = indices
```
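The switch from `if class_mapping:` to `if class_mapping is not None:` fixes a truthiness bug: a user-provided but falsy mapping (for instance an empty list) used to fall through to the subset branch. Illustration in plain Python, no library code:

```python
class_mapping = []
print(bool(class_mapping))        # False: the old check skipped this branch
print(class_mapping is not None)  # True: the new check honors the argument
```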

```diff
@@ -972,9 +975,7 @@ def __init__(self,
             targets_adapter=targets_adapter)

     def _get_single_item(self, idx: int):
-        return self._process_pattern(
-            self._dataset[idx], idx,
-            isinstance(self._original_dataset, AvalancheDataset))
+        return self._process_pattern(self._dataset[idx], idx)

     def _initialize_targets_sequence(
             self, dataset, targets, dataset_type, targets_adapter) \
```
```diff
@@ -1014,10 +1015,11 @@ def _initialize_task_labels_sequence(
             # case the user just wants to obtain a dataset in which the
             # position of the patterns has been changed according to
             # "indices". This "if" will take care of the corner case, too.
-            return LazyClassMapping(task_labels, indices=self._indices)
+            return SubSequence(task_labels, indices=self._indices,
+                               converter=int)
         elif len(task_labels) == len(dataset):
             # task_labels refers to the subset
-            return task_labels
+            return SubSequence(task_labels, converter=int)
         else:
             raise ValueError(
                 'Invalid amount of task labels. It must be equal to the '
```
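`SubSequence` replaces `LazyClassMapping` here. Judging from its uses in this diff, it is a lazy view over a sequence with optional index remapping and per-element conversion; a sketch under that assumption, with the import path taken from this file's relative import:

```python
from avalanche.benchmarks.utils.dataset_utils import SubSequence

labels = ['0', '1', '2', '3']
view = SubSequence(labels, indices=[3, 1], converter=int)
print(list(view))  # [3, 1]: remapped by index, converted to int lazily
```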
```diff
@@ -1026,13 +1028,7 @@
                     len(task_labels), len(self._original_dataset),
                     len(dataset)))

-        if hasattr(self._original_dataset, 'targets_task_labels'):
-            # The original dataset is probably a dataset of this class
-            return LazyClassMapping(self._original_dataset.targets_task_labels,
-                                    indices=self._indices)
-
-        # No task labels found. Set all task labels to 0 (in a lazy way).
-        return ConstantSequence(0, len(dataset))
+        return super()._initialize_task_labels_sequence(dataset, None)


 class AvalancheTensorDataset(AvalancheDataset[T_co, TTargetType]):
```
```diff
@@ -1303,9 +1299,7 @@ def _get_single_item(self, idx: int):

         single_element = self._dataset_list[dataset_idx][internal_idx]

-        return self._process_pattern(
-            single_element, idx,
-            isinstance(self._dataset_list[dataset_idx], AvalancheDataset))
+        return self._process_pattern(single_element, idx)

     def _fork_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
         dataset_copy = super()._fork_dataset()
```
```diff
@@ -1353,8 +1347,7 @@ def _initialize_task_labels_sequence(
         concat_t_labels = []
         for dataset_idx, single_dataset in enumerate(self._dataset_list):
             concat_t_labels.append(super()._initialize_task_labels_sequence(
-                single_dataset, None
-            ))
+                single_dataset, None))

         return LazyConcatTargets(concat_t_labels)
```
```diff
@@ -1595,50 +1588,137 @@ def train_eval_avalanche_datasets(
     return train, test


+def _traverse_supported_dataset(
+        dataset, values_selector: Callable[[Dataset, List[int]], List],
+        indices=None) -> List:
+
+    initial_error = None
+    try:
+        result = values_selector(dataset, indices)
+        if result is not None:
+            return result
+    except BaseException as e:
+        initial_error = e
+
+    if isinstance(dataset, Subset):
+        if indices is None:
+            indices = range(len(dataset))
+        indices = [dataset.indices[x] for x in indices]
+        return list(_traverse_supported_dataset(
+            dataset.dataset, values_selector, indices))
+
+    if isinstance(dataset, ConcatDataset):
+        result = []
+        if indices is None:
+            for c_dataset in dataset.datasets:
+                result += list(_traverse_supported_dataset(
+                    c_dataset, values_selector, indices))
+            return result
+
+        datasets_to_indexes = defaultdict(list)
+        indexes_to_dataset = []
+        datasets_len = []
+        recursion_result = []
+
+        all_size = 0
+        for c_dataset in dataset.datasets:
+            len_dataset = len(c_dataset)
+            datasets_len.append(len_dataset)
+            all_size += len_dataset
+
+        for subset_idx in indices:
+            dataset_idx, pattern_idx = \
+                find_list_from_index(subset_idx, datasets_len, all_size)
+            datasets_to_indexes[dataset_idx].append(pattern_idx)
+            indexes_to_dataset.append(dataset_idx)
+
+        for dataset_idx, c_dataset in enumerate(dataset.datasets):
+            recursion_result.append(deque(_traverse_supported_dataset(
+                c_dataset, values_selector, datasets_to_indexes[dataset_idx])))
+
+        result = []
+        for idx in range(len(indices)):
+            dataset_idx = indexes_to_dataset[idx]
+            result.append(recursion_result[dataset_idx].popleft())
+
+        return result
+
+    if initial_error is not None:
+        raise initial_error
+
+    raise ValueError('Error: can\'t find the needed data in the given dataset')
+
+
+def _select_targets(dataset, indices):
+    if hasattr(dataset, 'targets'):
+        # Standard supported dataset
+        found_targets = dataset.targets
+    elif hasattr(dataset, 'tensors'):
+        # Support for PyTorch TensorDataset
+        if len(dataset.tensors) < 2:
+            raise ValueError('Tensor dataset has not enough tensors: '
+                             'at least 2 are required.')
+        found_targets = dataset.tensors[1]
+    else:
+        raise ValueError(
+            'Unsupported dataset: must have a valid targets field '
+            'or has to be a Tensor Dataset with at least 2 '
+            'Tensors')
+
+    if indices is not None:
+        found_targets = SubSequence(found_targets, indices=indices)
+
+    return found_targets
+
+
+def _select_task_labels(dataset, indices):
+    found_task_labels = None
+    if hasattr(dataset, 'targets_task_labels'):
+        found_task_labels = dataset.targets_task_labels
+
+    if found_task_labels is None:
+        if isinstance(dataset, (Subset, ConcatDataset)):
+            return None  # Continue traversing
+
+    if found_task_labels is None:
+        if indices is None:
+            return ConstantSequence(0, len(dataset))
+        return ConstantSequence(0, len(indices))
+
+    if indices is not None:
+        found_task_labels = SubSequence(found_task_labels, indices=indices)
+
+    return found_task_labels
+
+
 def _make_target_from_supported_dataset(
         dataset: SupportedDataset,
         converter: Callable[[Any], TTargetType] = None) -> \
         Sequence[TTargetType]:
     if isinstance(dataset, AvalancheDataset):
         if converter is None:
             return dataset.targets
-        elif isinstance(dataset.targets, LazyTargetsConversion) and \
+        elif isinstance(dataset.targets,
+                        (SubSequence, LazyConcatTargets)) and \
                 dataset.targets.converter == converter:
             return dataset.targets
         elif isinstance(dataset.targets, LazyClassMapping) and converter == int:
             # LazyClassMapping already outputs int targets
             return dataset.targets

-    # Support for PyTorch "Subset"
-    subset_indices = False
-    indices = range(len(dataset))
-    while isinstance(dataset, Subset):
-        subset_indices = True
-        indices = [dataset.indices[x] for x in indices]
-        dataset = dataset.dataset
+    targets = _traverse_supported_dataset(dataset, _select_targets)

-    if hasattr(dataset, 'targets'):
-        # Standard supported dataset
-        found_targets = dataset.targets
-    elif hasattr(dataset, 'tensors'):
-        # Support for PyTorch TensorDataset
-        if len(dataset.tensors) < 2:
-            raise ValueError('Tensor dataset has not enough tensors: '
-                             'at least 2 are required.')
-        found_targets = dataset.tensors[1]
-    else:
-        raise ValueError('Unsupported dataset: must have a valid targets field '
-                         'or has to be a Tensor Dataset with at least 2 '
-                         'Tensors')
+    return SubSequence(targets, converter=converter)

-    if subset_indices:
-        found_targets = SubSequence(found_targets, indices)

-    return LazyTargetsConversion(found_targets, converter=converter)
+def _make_task_labels_from_supported_dataset(dataset: SupportedDataset) -> \
+        Sequence[int]:
+    if isinstance(dataset, AvalancheDataset):
+        return dataset.targets_task_labels

+    task_labels = _traverse_supported_dataset(dataset, _select_task_labels)

-def _is_tensor_dataset(dataset: SupportedDataset) -> bool:
-    return hasattr(dataset, 'tensors') and len(dataset.tensors) >= 2
+    return SubSequence(task_labels, converter=int)


 __all__ = [
```
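The core addition of this commit is `_traverse_supported_dataset`, which recursively resolves a query (targets or task labels) through arbitrarily nested `Subset`/`ConcatDataset` wrappers instead of the old single-level `while isinstance(dataset, Subset)` loop. A self-contained simplification of the idea, not the library code; this sketch's `ConcatDataset` branch ignores partial index lists for brevity:

```python
import torch
from torch.utils.data import TensorDataset, Subset, ConcatDataset

base = TensorDataset(torch.rand(8, 2), torch.arange(8))
nested = ConcatDataset([Subset(base, [1, 3]), Subset(base, [0, 2])])

def targets_of(dataset, indices=None):
    if hasattr(dataset, 'tensors'):         # leaf: a TensorDataset
        t = dataset.tensors[1].tolist()
        return t if indices is None else [t[i] for i in indices]
    if isinstance(dataset, Subset):         # remap indices, then recurse
        idx = range(len(dataset)) if indices is None else indices
        return targets_of(dataset.dataset, [dataset.indices[i] for i in idx])
    if isinstance(dataset, ConcatDataset):  # concatenate children's answers
        out = []
        for child in dataset.datasets:
            out += targets_of(child)
        return out
    raise ValueError('unsupported dataset')

print(targets_of(nested))  # [1, 3, 0, 2]
```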
