Skip to content

Commit 2dcc7e8

Browse files
authored
Merge pull request #832 from superannotateai/wm_classes
Wm classes
2 parents f6384e6 + 3680feb commit 2dcc7e8

File tree

17 files changed

+424
-130
lines changed

17 files changed

+424
-130
lines changed

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1973,6 +1973,118 @@ def search_annotation_classes(
19731973
raise AppException(response.errors)
19741974
return BaseSerializer.serialize_iterable(response.data)
19751975

1976+
def update_annotation_class(
1977+
self,
1978+
project: Union[NotEmptyStr, int],
1979+
name: NotEmptyStr,
1980+
attribute_groups: List[dict],
1981+
):
1982+
"""Updates an existing annotation class by submitting a full, updated attribute_groups payload.
1983+
You can add new attribute groups, add new attribute values, rename attribute groups, rename attribute values,
1984+
delete attribute groups, delete attribute values, update attribute group types, update default attributes,
1985+
and update the required state.
1986+
1987+
.. warning::
1988+
This operation replaces the entire attribute group structure of the annotation class.
1989+
Any attribute groups or attribute values omitted from the payload will be permanently removed.
1990+
Existing annotations that reference removed attribute groups or attributes will lose their associated values.
1991+
1992+
:param project: The name or ID of the project.
1993+
:type project: Union[str, int]
1994+
1995+
:param name: The name of the annotation class to update.
1996+
:type name: str
1997+
1998+
:param attribute_groups: The full list of attribute groups for the class.
1999+
Each attribute group may contain:
2000+
2001+
- id (optional, required for existing groups)
2002+
- group_type (required): "radio", "checklist", "text", "numeric", or "ocr"
2003+
- name (required)
2004+
- isRequired (optional)
2005+
- default_value (optional)
2006+
- attributes (required, list)
2007+
2008+
Each attribute may contain:
2009+
2010+
- id (optional, required for existing attributes)
2011+
- name (required)
2012+
2013+
:type attribute_groups: list of dicts
2014+
2015+
Request Example:
2016+
::
2017+
2018+
# Retrieve existing annotation class
2019+
classes = client.search_annotation_classes(project="Medical Project", name="Organ")
2020+
existing_class = classes[0]
2021+
2022+
# Modify attribute groups
2023+
updated_groups = existing_class["attribute_groups"]
2024+
2025+
# Add a new attribute to an existing group
2026+
updated_groups[0]["attributes"].append({"name": "Kidney"})
2027+
2028+
# Add a new attribute group
2029+
updated_groups.append({
2030+
"group_type": "radio",
2031+
"name": "Severity",
2032+
"attributes": [
2033+
{"name": "Mild"},
2034+
{"name": "Moderate"},
2035+
{"name": "Severe"}
2036+
],
2037+
"default_value": "Mild"
2038+
})
2039+
2040+
# Update the annotation class
2041+
client.update_annotation_class(
2042+
project="Medical Project",
2043+
name="Organ",
2044+
attribute_groups=updated_groups
2045+
)
2046+
"""
2047+
project = self.controller.get_project(project)
2048+
2049+
# Find the annotation class by nam
2050+
annotation_classes = self.controller.annotation_classes.list(
2051+
condition=Condition("project_id", project.id, EQ)
2052+
).data
2053+
2054+
annotation_class = next(
2055+
(c for c in annotation_classes if c["name"] == name), None
2056+
)
2057+
2058+
if not annotation_class:
2059+
raise AppException("Annotation class not found in project.")
2060+
2061+
# Parse and validate attribute groups
2062+
annotation_class["attributeGroups"] = attribute_groups
2063+
for group in annotation_class["attributeGroups"]:
2064+
if "isRequired" in group:
2065+
group["is_required"] = group.pop("isRequired")
2066+
2067+
from lib.core.entities import WMAnnotationClassEntity
2068+
2069+
try:
2070+
# validate annotation class
2071+
annotation_class = WMAnnotationClassEntity.parse_obj(
2072+
BaseSerializer(annotation_class).serialize()
2073+
)
2074+
except ValidationError as e:
2075+
raise AppException(wrap_error(e))
2076+
2077+
# Update the annotation class with new attribute groups
2078+
2079+
response = self.controller.annotation_classes.update(
2080+
project=project, annotation_class=annotation_class
2081+
)
2082+
2083+
if response.errors:
2084+
raise AppException(response.errors)
2085+
2086+
return BaseSerializer(response.data).serialize(by_alias=True)
2087+
19762088
def set_project_status(self, project: NotEmptyStr, status: PROJECT_STATUS):
19772089
"""Set project status
19782090
@@ -2845,7 +2957,9 @@ def create_annotation_class(
28452957
)
28462958
if response.errors:
28472959
raise AppException(response.errors)
2848-
return BaseSerializer(response.data).serialize(exclude_unset=True)
2960+
return BaseSerializer(response.data).serialize(
2961+
exclude_unset=True, by_alias=False
2962+
)
28492963

28502964
def delete_annotation_class(
28512965
self, project: NotEmptyStr, annotation_class: Union[dict, NotEmptyStr]

src/superannotate/lib/core/entities/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from lib.core.entities.project import WorkflowEntity
2525
from lib.core.entities.project_entities import BaseEntity
2626
from lib.core.entities.project_entities import S3FileEntity
27+
from lib.core.entities.work_managament import WMAnnotationClassEntity
2728
from lib.core.entities.work_managament import WMProjectUserEntity
2829

30+
2931
__all__ = [
3032
# base
3133
"ConfigEntity",
@@ -60,4 +62,6 @@
6062
# multimodal
6163
"FormModel",
6264
"generate_classes_from_form",
65+
# work management
66+
"WMAnnotationClassEntity",
6367
]

src/superannotate/lib/core/entities/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,23 @@
2020
try:
2121
from pydantic import AbstractSetIntStr # noqa
2222
from pydantic import MappingIntStrAny # noqa
23+
2324
except ImportError:
2425
pass
2526
_missing = object()
2627

28+
from lib.core.pydantic_v1 import Color
29+
from lib.core.pydantic_v1 import ColorType
30+
from lib.core.pydantic_v1 import validator
31+
32+
33+
class HexColor(BaseModel):
34+
__root__: ColorType
35+
36+
@validator("__root__", pre=True)
37+
def validate_color(cls, v):
38+
return "#{:02X}{:02X}{:02X}".format(*Color(v).as_rgb_tuple())
39+
2740

2841
class StringDate(datetime):
2942
@classmethod

src/superannotate/lib/core/entities/classes.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,24 @@
1-
from datetime import datetime
21
from enum import Enum
32
from typing import Any
43
from typing import List
54
from typing import Optional
65

7-
from lib.core.entities.base import BaseModel
8-
from lib.core.entities.base import parse_datetime
6+
from lib.core.entities.base import HexColor
7+
from lib.core.entities.base import TimedBaseModel
98
from lib.core.enums import BaseTitledEnum
109
from lib.core.enums import ClassTypeEnum
11-
from lib.core.pydantic_v1 import Color
12-
from lib.core.pydantic_v1 import ColorType
1310
from lib.core.pydantic_v1 import Extra
1411
from lib.core.pydantic_v1 import Field
1512
from lib.core.pydantic_v1 import StrictInt
1613
from lib.core.pydantic_v1 import StrictStr
17-
from lib.core.pydantic_v1 import validator
14+
1815

1916
DATE_REGEX = r"\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d(?:\.\d{3})Z"
2017
DATE_TIME_FORMAT_ERROR_MESSAGE = (
2118
"does not match expected format YYYY-MM-DDTHH:MM:SS.fffZ"
2219
)
2320

2421

25-
class HexColor(BaseModel):
26-
__root__: ColorType
27-
28-
@validator("__root__", pre=True)
29-
def validate_color(cls, v):
30-
return "#{:02X}{:02X}{:02X}".format(*Color(v).as_rgb_tuple())
31-
32-
3322
class GroupTypeEnum(str, Enum):
3423
RADIO = "radio"
3524
CHECKLIST = "checklist"
@@ -38,22 +27,6 @@ class GroupTypeEnum(str, Enum):
3827
OCR = "ocr"
3928

4029

41-
class StringDate(datetime):
42-
@classmethod
43-
def __get_validators__(cls):
44-
yield parse_datetime
45-
yield cls.validate
46-
47-
@classmethod
48-
def validate(cls, v: datetime):
49-
return v.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
50-
51-
52-
class TimedBaseModel(BaseModel):
53-
createdAt: StringDate = Field(None, alias="createdAt")
54-
updatedAt: StringDate = Field(None, alias="updatedAt")
55-
56-
5730
class Attribute(TimedBaseModel):
5831
id: Optional[StrictInt]
5932
group_id: Optional[StrictInt]

src/superannotate/lib/core/entities/work_managament.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
from typing import Optional
77
from typing import Union
88

9+
from lib.core.entities.base import HexColor
910
from lib.core.entities.base import TimedBaseModel
11+
from lib.core.enums import WMClassTypeEnum
12+
from lib.core.enums import WMGroupTypeEnum
1013
from lib.core.enums import WMUserStateEnum
1114
from lib.core.exceptions import AppException
1215
from lib.core.pydantic_v1 import BaseModel
1316
from lib.core.pydantic_v1 import Extra
1417
from lib.core.pydantic_v1 import Field
1518
from lib.core.pydantic_v1 import parse_datetime
1619
from lib.core.pydantic_v1 import root_validator
20+
from lib.core.pydantic_v1 import StrictInt
21+
from lib.core.pydantic_v1 import StrictStr
1722
from lib.core.pydantic_v1 import validator
1823

1924

@@ -209,3 +214,71 @@ def check_weight_and_value(cls, values):
209214
):
210215
raise AppException("Weight and Value must both be set or both be None.")
211216
return values
217+
218+
219+
class WMAttribute(TimedBaseModel):
220+
id: Optional[StrictInt]
221+
group_id: Optional[StrictInt]
222+
project_id: Optional[StrictInt]
223+
name: Optional[StrictStr]
224+
default: Any
225+
226+
class Config:
227+
extra = Extra.ignore
228+
229+
def __hash__(self):
230+
return hash(f"{self.id}{self.group_id}{self.name}")
231+
232+
233+
class WMAttributeGroup(TimedBaseModel):
234+
id: Optional[StrictInt]
235+
group_type: Optional[WMGroupTypeEnum]
236+
class_id: Optional[StrictInt]
237+
name: Optional[StrictStr]
238+
isRequired: bool = Field(default=False, alias="is_required")
239+
attributes: Optional[List[WMAttribute]]
240+
default_value: Any
241+
242+
class Config:
243+
extra = Extra.ignore
244+
245+
def __hash__(self):
246+
return hash(f"{self.id}{self.class_id}{self.name}")
247+
248+
249+
class WMAnnotationClassEntity(TimedBaseModel):
250+
id: Optional[StrictInt]
251+
project_id: Optional[StrictInt]
252+
type: WMClassTypeEnum = WMClassTypeEnum.OBJECT
253+
name: StrictStr
254+
color: HexColor
255+
attribute_groups: List[WMAttributeGroup] = Field(
256+
default=[], alias="attributeGroups"
257+
)
258+
259+
def __hash__(self):
260+
return hash(f"{self.id}{self.type}{self.name}")
261+
262+
class Config:
263+
extra = Extra.ignore
264+
json_encoders = {
265+
HexColor: lambda v: v.__root__,
266+
# WMClassTypeEnum: lambda v: v.name,
267+
}
268+
validate_assignment = True
269+
270+
@validator("type", pre=True)
271+
def validate_type(cls, v):
272+
if isinstance(v, WMClassTypeEnum):
273+
return v
274+
if isinstance(v, str):
275+
# Try by value first (e.g., "object")
276+
for member in WMClassTypeEnum:
277+
if member.value == v:
278+
return member
279+
# Try by name (e.g., "OBJECT")
280+
try:
281+
return WMClassTypeEnum[v.upper()]
282+
except KeyError:
283+
pass
284+
raise ValueError(f"Invalid type: {v}")

src/superannotate/lib/core/enums.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ def get_value(cls, name):
175175
return cls.OBJECT.value
176176

177177

178+
class WMClassTypeEnum(Enum):
179+
OBJECT = "object"
180+
TAG = "tag"
181+
RELATIONSHIP = "relationship"
182+
183+
184+
class WMGroupTypeEnum(Enum):
185+
RADIO = "radio"
186+
CHECKLIST = "checklist"
187+
NUMERIC = "numeric"
188+
TEXT = "text"
189+
OCR = "ocr"
190+
191+
178192
class IntegrationTypeEnum(BaseTitledEnum):
179193
AWS = "aws", 1
180194
GCP = "gcp", 2
@@ -193,7 +207,7 @@ class TrainingStatus(BaseTitledEnum):
193207
FAILED_AFTER_EVALUATION_WITH_SAVE_MODEL = "FailedAfterEvaluationWithSavedModel", 6
194208

195209

196-
class CustomFieldEntityEnum(str, Enum):
210+
class CustomFieldEntityEnum(Enum):
197211
CONTRIBUTOR = "Contributor"
198212
TEAM = "Team"
199213
PROJECT = "Project"
@@ -207,6 +221,6 @@ class CustomFieldType(Enum):
207221
NUMERIC = 5
208222

209223

210-
class WMUserStateEnum(str, Enum):
224+
class WMUserStateEnum(Enum):
211225
Pending = "PENDING"
212226
Confirmed = "CONFIRMED"

src/superannotate/lib/core/service_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def __str__(self):
138138
return f"Status: {self.status_code}, Error {self.error}"
139139

140140

141+
class WMClassesResponse(ServiceResponse):
142+
res_data: List[entities.WMAnnotationClassEntity] = None
143+
144+
141145
class BaseItemResponse(ServiceResponse):
142146
res_data: entities.BaseItemEntity = None
143147

@@ -186,6 +190,10 @@ class IntegrationListResponse(ServiceResponse):
186190
res_data: _IntegrationResponse
187191

188192

193+
class AnnotationClassResponse(ServiceResponse):
194+
res_data: entities.AnnotationClassEntity = None
195+
196+
189197
class AnnotationClassListResponse(ServiceResponse):
190198
res_data: List[entities.AnnotationClassEntity] = None
191199

0 commit comments

Comments
 (0)