@@ -19,7 +19,8 @@ class AnnotatedObjectsDataset(Dataset):
1919 def __init__ (self , data_path : Union [str , Path ], split : SplitType , keys : List [str ], target_image_size : int ,
2020 min_object_area : float , min_objects_per_image : int , max_objects_per_image : int ,
2121 crop_method : CropMethodType , random_flip : bool , no_tokens : int , use_group_parameter : bool ,
22- encode_crop : bool , category_allow_list_target : str , category_mapping_target : str ):
22+ encode_crop : bool , category_allow_list_target : str , category_mapping_target : str ,
23+ no_object_classes : Optional [int ] = None ):
2324 self .data_path = data_path
2425 self .split = split
2526 self .keys = keys
@@ -48,6 +49,7 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
4849 self .category_mapping = {}
4950 if category_mapping_target :
5051 self .category_mapping = load_object_from_string (category_mapping_target )
52+ self .no_object_classes = no_object_classes
5153
5254 def build_paths (self , top_level : Union [str , Path ]) -> Dict [str , Path ]:
5355 top_level = Path (top_level )
@@ -104,7 +106,7 @@ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool],
104106
105107 @property
106108 def no_classes (self ) -> int :
107- return len (self .categories )
109+ return self . no_object_classes if self . no_object_classes else len (self .categories )
108110
109111 @property
110112 def conditional_builders (self ) -> ObjectsCenterPointsConditionalBuilder :
0 commit comments