11from pathlib import Path
22from typing import Optional , List , Callable , Dict , Any , Union
3+ import warnings
34
45import PIL .Image as pil_image
56from torch import Tensor
89
910from taming .data .conditional_builder .objects_bbox import ObjectsBoundingBoxConditionalBuilder
1011from taming .data .conditional_builder .objects_center_points import ObjectsCenterPointsConditionalBuilder
12+ from taming .data .conditional_builder .utils import load_object_from_string
1113from taming .data .helper_types import BoundingBox , CropMethodType , Image , Annotation , SplitType
1214from taming .data .image_transforms import CenterCropReturnCoordinates , RandomCrop1dReturnCoordinates , \
1315 Random2dCropReturnCoordinates , RandomHorizontalFlipReturn , convert_pil_to_tensor
@@ -17,7 +19,7 @@ class AnnotatedObjectsDataset(Dataset):
1719 def __init__ (self , data_path : Union [str , Path ], split : SplitType , keys : List [str ], target_image_size : int ,
1820 min_object_area : float , min_objects_per_image : int , max_objects_per_image : int ,
1921 crop_method : CropMethodType , random_flip : bool , no_tokens : int , use_group_parameter : bool ,
20- encode_crop : bool ):
22+ encode_crop : bool , category_allow_list_target : str , category_mapping_target : str ):
2123 self .data_path = data_path
2224 self .split = split
2325 self .keys = keys
@@ -40,6 +42,12 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
4042 self .transform_functions : List [Callable ] = self .setup_transform (target_image_size , crop_method , random_flip )
4143 self .paths = self .build_paths (self .data_path )
4244 self ._conditional_builders = None
45+ if category_allow_list_target :
46+ allow_list = load_object_from_string (category_allow_list_target )
47+ self .category_allow_list = {name for name , _ in allow_list }
48+ self .category_mapping = {}
49+ if category_mapping_target :
50+ self .category_mapping = load_object_from_string (category_mapping_target )
4351
4452 def build_paths (self , top_level : Union [str , Path ]) -> Dict [str , Path ]:
4553 top_level = Path (top_level )
@@ -123,12 +131,22 @@ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
123131 return self ._conditional_builders
124132
125133 def filter_categories (self ) -> None :
126- pass
134+ if self .category_allow_list :
135+ self .categories = {id_ : cat for id_ , cat in self .categories .items () if cat .name in self .category_allow_list }
136+ if self .category_mapping :
137+ self .categories = {id_ : cat for id_ , cat in self .categories .items () if cat .id not in self .category_mapping }
127138
128139 def setup_category_id_and_number (self ) -> None :
129140 self .category_ids = list (self .categories .keys ())
130141 self .category_ids .sort ()
142+ if '/m/01s55n' in self .category_ids :
143+ self .category_ids .remove ('/m/01s55n' )
144+ self .category_ids .append ('/m/01s55n' )
131145 self .category_number = {category_id : i for i , category_id in enumerate (self .category_ids )}
146+ if self .category_allow_list is not None and self .category_mapping is None \
147+ and len (self .category_ids ) != len (self .category_allow_list ):
148+ warnings .warn ('Unexpected number of categories: Mismatch with category_allow_list. '
149+ 'Make sure all names in category_allow_list exist.' )
132150
133151 def clean_up_annotations_and_image_descriptions (self ) -> None :
134152 image_id_set = set (self .image_ids )
0 commit comments