Skip to content

Commit 0821d00

Browse files
committed
Allow to externally set class number for compatibility
1 parent 6876a0f commit 0821d00

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

taming/data/annotated_objects_dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ class AnnotatedObjectsDataset(Dataset):
1919
def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
2020
min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
2121
crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
22-
encode_crop: bool, category_allow_list_target: str, category_mapping_target: str):
22+
encode_crop: bool, category_allow_list_target: str, category_mapping_target: str,
23+
no_object_classes: Optional[int] = None):
2324
self.data_path = data_path
2425
self.split = split
2526
self.keys = keys
@@ -48,6 +49,7 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
4849
self.category_mapping = {}
4950
if category_mapping_target:
5051
self.category_mapping = load_object_from_string(category_mapping_target)
52+
self.no_object_classes = no_object_classes
5153

5254
def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
5355
top_level = Path(top_level)
@@ -104,7 +106,7 @@ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool],
104106

105107
@property
106108
def no_classes(self) -> int:
107-
return len(self.categories)
109+
return self.no_object_classes if self.no_object_classes else len(self.categories)
108110

109111
@property
110112
def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:

0 commit comments

Comments
 (0)