1818"""
1919import copy
2020import warnings
21- from collections import OrderedDict
21+ from collections import OrderedDict , defaultdict , deque
2222from enum import Enum , auto
2323
2424import torch
2525from torch .utils .data .dataloader import default_collate
26- from torch .utils .data .dataset import Dataset , Subset
26+ from torch .utils .data .dataset import Dataset , Subset , ConcatDataset
2727from torchvision .transforms import Compose
2828
2929from .dataset_utils import manage_advanced_indexing , \
30- SequenceDataset , ClassificationSubset , LazyTargetsConversion , \
30+ SequenceDataset , ClassificationSubset , \
3131 LazyConcatIntTargets , find_list_from_index , ConstantSequence , \
3232 LazyClassMapping , optimize_sequence , SubSequence , LazyConcatTargets , \
33- SubsetWithTargets
33+ TupleTLabel
3434from .dataset_definitions import ITensorDataset , ClassificationDataset , \
3535 IDatasetWithTargets , ISupportedClassificationDataset
3636
4949XTransform = Optional [Callable [[Any ], Any ]]
5050YTransform = Optional [Callable [[Any ], TTargetType ]]
5151
52- SupportedDataset = Union [IDatasetWithTargets , ITensorDataset , Subset ]
52+ SupportedDataset = Union [IDatasetWithTargets , ITensorDataset , Subset ,
53+ ConcatDataset ]
5354
5455
5556class AvalancheDatasetType (Enum ):
@@ -302,8 +303,8 @@ def __radd__(self, other: Dataset) -> 'AvalancheDataset':
         return AvalancheConcatDataset([other, self])
 
     def __getitem__(self, idx) -> Union[T_co, Sequence[T_co]]:
-        return manage_advanced_indexing(
-            idx, self._get_single_item, len(self), self.collate_fn)
+        return TupleTLabel(manage_advanced_indexing(
+            idx, self._get_single_item, len(self), self.collate_fn))
 
     def __len__(self):
         return len(self._dataset)
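Note: from here on, every element returned by __getitem__ is wrapped in TupleTLabel. A minimal sketch of the mechanism, assuming TupleTLabel is a plain tuple subclass used as a marker (the class below is an illustrative stand-in, not the one imported from .dataset_utils):

import torch

class TupleTLabel(tuple):
    # Illustrative stand-in: a tuple subtype marking "task label appended".
    pass

element = TupleTLabel((torch.zeros(3), 7, 0))  # (pattern, target, task label)

# The marker lets _process_pattern (next hunks) detect an attached task
# label without inspecting the wrapped dataset's concrete type:
if isinstance(element, TupleTLabel):
    element = element[:-1]  # strip the task label before transforms
assert len(element) == 2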
@@ -598,11 +599,10 @@ def _freeze_dataset_group(dataset_copy: TAvalancheDataset,
         dataset_copy._freeze_original_dataset(group_name)
 
     def _get_single_item(self, idx: int):
-        return self._process_pattern(
-            self._dataset[idx], idx,
-            isinstance(self._dataset, AvalancheDataset))
+        return self._process_pattern(self._dataset[idx], idx)
 
-    def _process_pattern(self, element: Tuple, idx: int, has_task_label: bool):
+    def _process_pattern(self, element: Tuple, idx: int):
+        has_task_label = isinstance(element, TupleTLabel)
         if has_task_label:
             element = element[:-1]
 
@@ -611,7 +611,8 @@ def _process_pattern(self, element: Tuple, idx: int, has_task_label: bool):
 
         pattern, label = self._apply_transforms(pattern, label)
 
-        return (pattern, label, *element[2:], self.targets_task_labels[idx])
+        return TupleTLabel((pattern, label, *element[2:],
+                            self.targets_task_labels[idx]))
 
     def _apply_transforms(self, pattern: Any, label: int):
         frozen_group = self._frozen_transforms[self.current_transform_group]
@@ -710,6 +711,11 @@ def _initialize_targets_sequence(self, dataset, targets,
         if targets is not None:
             # User defined targets always take precedence
             # Note: no adapter is applied!
+            if len(targets) != len(dataset):
+                raise ValueError(
+                    'Invalid amount of target labels. It must be equal to the '
+                    'number of patterns in the dataset. Got {}, expected '
+                    '{}!'.format(len(targets), len(dataset)))
             return targets
 
         if targets_adapter is None:
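Note: the new check mirrors the one already performed for task labels. A minimal sketch of the added validation, using only standard PyTorch (the 3-pattern dataset below is hypothetical):

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.zeros(3, 2), torch.tensor([0, 1, 2]))
targets = [0, 1]  # one label short of len(dataset)

try:
    if len(targets) != len(dataset):
        raise ValueError(
            'Invalid amount of target labels. It must be equal to the '
            'number of patterns in the dataset. Got {}, expected '
            '{}!'.format(len(targets), len(dataset)))
except ValueError as err:
    print(err)  # ... Got 2, expected 3!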
@@ -731,13 +737,7 @@ def _initialize_task_labels_sequence(
                     '{}!'.format(len(task_labels), len(dataset)))
             return task_labels
 
-        if hasattr(dataset, 'targets_task_labels'):
-            # Dataset is probably a dataset of this class
-            # Suppose that it is
-            return LazyTargetsConversion(dataset.targets_task_labels, int)
-
-        # No task labels found. Set all task labels to 0 (in a lazy way).
-        return ConstantSequence(0, len(dataset))
+        return _make_task_labels_from_supported_dataset(dataset)
 
     def _initialize_collate_fn(self, dataset, dataset_type, collate_fn):
         if collate_fn is not None:
@@ -952,11 +952,14 @@ def __init__(self,
             raise ValueError('class_mapping is defined but the dataset type'
                              ' is neither CLASSIFICATION or UNDEFINED.')
 
-        if class_mapping:
+        if class_mapping is not None:
             subset = ClassificationSubset(dataset, indices=indices,
                                           class_mapping=class_mapping)
+        elif indices is not None:
+            subset = Subset(dataset, indices=indices)
         else:
-            subset = SubsetWithTargets(dataset, indices=indices)
+            subset = dataset  # Exactly like a plain AvalancheDataset
+
         self._original_dataset = dataset
         self._indices = indices
 
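Note: the switch from a truthiness test to "is not None" matters because an empty class_mapping is falsy yet still a deliberate remapping request. A sketch of the new three-way dispatch in pure Python (pick_subset and the string results are illustrative only):

def pick_subset(dataset, indices=None, class_mapping=None):
    # Mirrors the branching above without the Avalanche classes.
    if class_mapping is not None:
        return 'ClassificationSubset'
    elif indices is not None:
        return 'Subset'
    return 'as-is'  # exactly like a plain AvalancheDataset

assert pick_subset('d') == 'as-is'
assert pick_subset('d', indices=[0]) == 'Subset'
assert pick_subset('d', class_mapping=[]) == 'ClassificationSubset'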
@@ -972,9 +975,7 @@ def __init__(self,
             targets_adapter=targets_adapter)
 
     def _get_single_item(self, idx: int):
-        return self._process_pattern(
-            self._dataset[idx], idx,
-            isinstance(self._original_dataset, AvalancheDataset))
+        return self._process_pattern(self._dataset[idx], idx)
 
     def _initialize_targets_sequence(
             self, dataset, targets, dataset_type, targets_adapter) \
@@ -1014,10 +1015,11 @@ def _initialize_task_labels_sequence(
                 # case the user just wants to obtain a dataset in which the
                 # position of the patterns has been changed according to
                 # "indices". This "if" will take care of the corner case, too.
-                return LazyClassMapping(task_labels, indices=self._indices)
+                return SubSequence(task_labels, indices=self._indices,
+                                   converter=int)
             elif len(task_labels) == len(dataset):
                 # task_labels refers to the subset
-                return task_labels
+                return SubSequence(task_labels, converter=int)
             else:
                 raise ValueError(
                     'Invalid amount of task labels. It must be equal to the '
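Note: both branches now return SubSequence with converter=int. A hedged stand-in for what such a lazy re-indexing and converting view does (the real class lives in .dataset_utils; this sketch is illustrative only):

class SubSequenceSketch:
    # Lazy view: re-indexes a parent sequence, converting items on access.
    def __init__(self, parent, indices=None, converter=None):
        self._parent = parent
        self._indices = indices
        self._converter = converter

    def __len__(self):
        return len(self._parent) if self._indices is None \
            else len(self._indices)

    def __getitem__(self, idx):
        parent_idx = idx if self._indices is None else self._indices[idx]
        value = self._parent[parent_idx]
        return value if self._converter is None else self._converter(value)

labels = SubSequenceSketch(['2', '0', '1'], indices=[2, 0], converter=int)
assert list(labels) == [1, 2]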
@@ -1026,13 +1028,7 @@ def _initialize_task_labels_sequence(
                         len(task_labels), len(self._original_dataset),
                         len(dataset)))
 
-        if hasattr(self._original_dataset, 'targets_task_labels'):
-            # The original dataset is probably a dataset of this class
-            return LazyClassMapping(self._original_dataset.targets_task_labels,
-                                    indices=self._indices)
-
-        # No task labels found. Set all task labels to 0 (in a lazy way).
-        return ConstantSequence(0, len(dataset))
+        return super()._initialize_task_labels_sequence(dataset, None)
 
 
 class AvalancheTensorDataset(AvalancheDataset[T_co, TTargetType]):
@@ -1303,9 +1299,7 @@ def _get_single_item(self, idx: int):
 
         single_element = self._dataset_list[dataset_idx][internal_idx]
 
-        return self._process_pattern(
-            single_element, idx,
-            isinstance(self._dataset_list[dataset_idx], AvalancheDataset))
+        return self._process_pattern(single_element, idx)
 
     def _fork_dataset(self: TAvalancheDataset) -> TAvalancheDataset:
         dataset_copy = super()._fork_dataset()
@@ -1353,8 +1347,7 @@ def _initialize_task_labels_sequence(
         concat_t_labels = []
         for dataset_idx, single_dataset in enumerate(self._dataset_list):
             concat_t_labels.append(super()._initialize_task_labels_sequence(
-                single_dataset, None
-            ))
+                single_dataset, None))
 
         return LazyConcatTargets(concat_t_labels)
 
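Note: per-dataset task label sequences are concatenated lazily through LazyConcatTargets. A hedged stand-in showing the idea of an index-translating concatenated view (illustrative, not the imported class):

from bisect import bisect_right
from itertools import accumulate

class LazyConcatSketch:
    # Lazy concatenation: translates a flat index into (sequence, offset).
    def __init__(self, sequences):
        self._sequences = sequences
        self._cumulative = list(accumulate(len(s) for s in sequences))

    def __len__(self):
        return self._cumulative[-1]

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(idx)
        seq_idx = bisect_right(self._cumulative, idx)
        offset = 0 if seq_idx == 0 else self._cumulative[seq_idx - 1]
        return self._sequences[seq_idx][idx - offset]

labels = LazyConcatSketch([[0, 0], [1, 1, 1]])
assert list(labels) == [0, 0, 1, 1, 1]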
@@ -1595,50 +1588,137 @@ def train_eval_avalanche_datasets(
     return train, test
 
 
+def _traverse_supported_dataset(
+        dataset, values_selector: Callable[[Dataset, List[int]], List],
+        indices=None) -> List:
+
+    initial_error = None
+    try:
+        result = values_selector(dataset, indices)
+        if result is not None:
+            return result
+    except BaseException as e:
+        initial_error = e
+
+    if isinstance(dataset, Subset):
+        if indices is None:
+            indices = range(len(dataset))
+        indices = [dataset.indices[x] for x in indices]
+        return list(_traverse_supported_dataset(
+            dataset.dataset, values_selector, indices))
+
+    if isinstance(dataset, ConcatDataset):
+        result = []
+        if indices is None:
+            for c_dataset in dataset.datasets:
+                result += list(_traverse_supported_dataset(
+                    c_dataset, values_selector, indices))
+            return result
+
+        datasets_to_indexes = defaultdict(list)
+        indexes_to_dataset = []
+        datasets_len = []
+        recursion_result = []
+
+        all_size = 0
+        for c_dataset in dataset.datasets:
+            len_dataset = len(c_dataset)
+            datasets_len.append(len_dataset)
+            all_size += len_dataset
+
+        for subset_idx in indices:
+            dataset_idx, pattern_idx = \
+                find_list_from_index(subset_idx, datasets_len, all_size)
+            datasets_to_indexes[dataset_idx].append(pattern_idx)
+            indexes_to_dataset.append(dataset_idx)
+
+        for dataset_idx, c_dataset in enumerate(dataset.datasets):
+            recursion_result.append(deque(_traverse_supported_dataset(
+                c_dataset, values_selector, datasets_to_indexes[dataset_idx])))
+
+        result = []
+        for idx in range(len(indices)):
+            dataset_idx = indexes_to_dataset[idx]
+            result.append(recursion_result[dataset_idx].popleft())
+
+        return result
+
+    if initial_error is not None:
+        raise initial_error
+
+    raise ValueError('Error: can\'t find the needed data in the given dataset')
+
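Note: _traverse_supported_dataset first asks the selector for values and, failing that, recurses into Subset/ConcatDataset while remapping indices at each level. A minimal sketch of the nesting it has to unwind, using only standard PyTorch classes:

import torch
from torch.utils.data import TensorDataset, Subset, ConcatDataset

base = TensorDataset(torch.zeros(6, 2), torch.tensor([0, 1, 2, 3, 4, 5]))
nested = ConcatDataset([Subset(base, [5, 4]), Subset(base, [0, 1])])

# Resolving concat index 1 by hand, the way the traversal does: it falls in
# the first Subset (position 1), which maps to base index 4, target 4.
_, y = nested[1]
assert y.item() == 4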
+
+def _select_targets(dataset, indices):
+    if hasattr(dataset, 'targets'):
+        # Standard supported dataset
+        found_targets = dataset.targets
+    elif hasattr(dataset, 'tensors'):
+        # Support for PyTorch TensorDataset
+        if len(dataset.tensors) < 2:
+            raise ValueError('Tensor dataset has not enough tensors: '
+                             'at least 2 are required.')
+        found_targets = dataset.tensors[1]
+    else:
+        raise ValueError(
+            'Unsupported dataset: must have a valid targets field '
+            'or has to be a Tensor Dataset with at least 2 '
+            'Tensors')
+
+    if indices is not None:
+        found_targets = SubSequence(found_targets, indices=indices)
+
+    return found_targets
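Note: in the TensorDataset branch, the second tensor doubles as the targets field. A quick pure-PyTorch illustration:

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.zeros(4, 2), torch.tensor([3, 1, 2, 0]))
found_targets = dataset.tensors[1]  # tensors[0] holds the patterns
assert [int(t) for t in found_targets] == [3, 1, 2, 0]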
+
+
+def _select_task_labels(dataset, indices):
+    found_task_labels = None
+    if hasattr(dataset, 'targets_task_labels'):
+        found_task_labels = dataset.targets_task_labels
+
+    if found_task_labels is None:
+        if isinstance(dataset, (Subset, ConcatDataset)):
+            return None  # Continue traversing
+
+    if found_task_labels is None:
+        if indices is None:
+            return ConstantSequence(0, len(dataset))
+        return ConstantSequence(0, len(indices))
+
+    if indices is not None:
+        found_task_labels = SubSequence(found_task_labels, indices=indices)
+
+    return found_task_labels
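Note: leaf datasets without task labels fall back to ConstantSequence, an O(1)-memory "all zeros" view. A hedged stand-in (illustrative, not the imported class):

class ConstantSequenceSketch:
    # Lazy constant sequence: one value, any length, O(1) memory.
    def __init__(self, value, length):
        self._value = value
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if idx >= self._length:
            raise IndexError(idx)
        return self._value

task_labels = ConstantSequenceSketch(0, 10_000_000)
assert task_labels[123_456] == 0 and len(task_labels) == 10_000_000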
+
+
 def _make_target_from_supported_dataset(
         dataset: SupportedDataset,
         converter: Callable[[Any], TTargetType] = None) -> \
         Sequence[TTargetType]:
     if isinstance(dataset, AvalancheDataset):
         if converter is None:
             return dataset.targets
-        elif isinstance(dataset.targets, LazyTargetsConversion) and \
+        elif isinstance(dataset.targets,
+                        (SubSequence, LazyConcatTargets)) and \
                 dataset.targets.converter == converter:
             return dataset.targets
         elif isinstance(dataset.targets, LazyClassMapping) and converter == int:
             # LazyClassMapping already outputs int targets
             return dataset.targets
 
-    # Support for PyTorch "Subset"
-    subset_indices = False
-    indices = range(len(dataset))
-    while isinstance(dataset, Subset):
-        subset_indices = True
-        indices = [dataset.indices[x] for x in indices]
-        dataset = dataset.dataset
+    targets = _traverse_supported_dataset(dataset, _select_targets)
 
-    if hasattr(dataset, 'targets'):
-        # Standard supported dataset
-        found_targets = dataset.targets
-    elif hasattr(dataset, 'tensors'):
-        # Support for PyTorch TensorDataset
-        if len(dataset.tensors) < 2:
-            raise ValueError('Tensor dataset has not enough tensors: '
-                             'at least 2 are required.')
-        found_targets = dataset.tensors[1]
-    else:
-        raise ValueError('Unsupported dataset: must have a valid targets field '
-                         'or has to be a Tensor Dataset with at least 2 '
-                         'Tensors')
+    return SubSequence(targets, converter=converter)
 
-    if subset_indices:
-        found_targets = SubSequence(found_targets, indices)
 
-    return LazyTargetsConversion(found_targets, converter=converter)
+def _make_task_labels_from_supported_dataset(dataset: SupportedDataset) -> \
+        Sequence[int]:
+    if isinstance(dataset, AvalancheDataset):
+        return dataset.targets_task_labels
 
+    task_labels = _traverse_supported_dataset(dataset, _select_task_labels)
 
-def _is_tensor_dataset(dataset: SupportedDataset) -> bool:
-    return hasattr(dataset, 'tensors') and len(dataset.tensors) >= 2
+    return SubSequence(task_labels, converter=int)
 
 
 __all__ = [