Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ minversion = 3.7
log_cli=true
python_files = test_*.py
;pytest_plugins = ['pytest_profiling']
addopts = -n 6 --dist loadscope
;addopts = -n 6 --dist loadscope
11 changes: 3 additions & 8 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from lib.core.jsx_conditions import Join
from lib.core.jsx_conditions import Fields
from lib.core.entities.items import ProjectCategoryEntity
from lib.core.entities import WMAnnotationClassEntity

logger = logging.getLogger("sa")

Expand Down Expand Up @@ -2059,13 +2060,7 @@ def update_annotation_class(
raise AppException("Annotation class not found in project.")

# Parse and validate attribute groups
annotation_class["attributeGroups"] = attribute_groups
for group in annotation_class["attributeGroups"]:
if "isRequired" in group:
group["is_required"] = group.pop("isRequired")

from lib.core.entities import WMAnnotationClassEntity

annotation_class["attribute_groups"] = attribute_groups
try:
# validate annotation class
annotation_class = WMAnnotationClassEntity.parse_obj(
Expand All @@ -2083,7 +2078,7 @@ def update_annotation_class(
if response.errors:
raise AppException(response.errors)

return BaseSerializer(response.data).serialize(by_alias=True)
return BaseSerializer(response.data).serialize(by_alias=False)

def set_project_status(self, project: NotEmptyStr, status: PROJECT_STATUS):
"""Set project status
Expand Down
38 changes: 38 additions & 0 deletions src/superannotate/lib/core/entities/work_managament.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,42 @@ class WMAttributeGroup(TimedBaseModel):
default_value: Any

class Config:
allow_population_by_field_name = True
extra = Extra.ignore

def __hash__(self):
return hash(f"{self.id}{self.class_id}{self.name}")

@validator("group_type", pre=True)
def validate_group_type(cls, v):
if v is None:
return v
if isinstance(v, WMGroupTypeEnum):
return v
if isinstance(v, str):
# Try by value first (e.g., "radio")
for member in WMGroupTypeEnum:
if member.value == v.lower():
return member
# Try by name (e.g., "RADIO" or "radio")
try:
return WMGroupTypeEnum[v.upper()]
except KeyError:
pass
raise ValueError(f"Invalid group_type: {v}")

def dict(self, *args, **kwargs):
by_alias = kwargs.get("by_alias", False)
data = super().dict(*args, **kwargs)

if by_alias and "group_type" in data:
if isinstance(data["group_type"], WMGroupTypeEnum):
data["group_type"] = data["group_type"].name
elif not by_alias and "group_type" in data:
if isinstance(data["group_type"], WMGroupTypeEnum):
data["group_type"] = data["group_type"].value
return data


class WMAnnotationClassEntity(TimedBaseModel):
id: Optional[StrictInt]
Expand All @@ -260,13 +291,20 @@ def __hash__(self):
return hash(f"{self.id}{self.type}{self.name}")

class Config:
allow_population_by_field_name = True
extra = Extra.ignore
json_encoders = {
HexColor: lambda v: v.__root__,
# WMClassTypeEnum: lambda v: v.name,
}
validate_assignment = True

def dict(self, *args, **kwargs):
data = super().dict(*args, **kwargs)
if "type" in data and isinstance(data["type"], WMClassTypeEnum):
data["type"] = data["type"].value
return data

@validator("type", pre=True)
def validate_type(cls, v):
if isinstance(v, WMClassTypeEnum):
Expand Down
2 changes: 1 addition & 1 deletion src/superannotate/lib/core/service_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __str__(self):


class WMClassesResponse(ServiceResponse):
res_data: List[entities.WMAnnotationClassEntity] = None
res_data: entities.WMAnnotationClassEntity = None


class BaseItemResponse(ServiceResponse):
Expand Down
3 changes: 2 additions & 1 deletion src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from lib.core.service_types import UploadAnnotationsResponse
from lib.core.service_types import UserLimitsResponse
from lib.core.service_types import UserResponse
from lib.core.service_types import WMClassesResponse
from lib.core.service_types import WMCustomFieldResponse
from lib.core.service_types import WMProjectListResponse
from lib.core.service_types import WMScoreListResponse
Expand Down Expand Up @@ -250,7 +251,7 @@ def update_annotation_class(
project_id: int,
class_id: int,
data: WMAnnotationClassEntity,
) -> ServiceResponse:
) -> WMClassesResponse:
raise NotImplementedError


Expand Down
2 changes: 1 addition & 1 deletion src/superannotate/lib/core/usecases/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def execute(self):
"role", constants.UserRole.CONTRIBUTOR.value, OperatorEnum.EQ
)
if user_to_retrieve:
_filter &= (Filter("email", user_to_retrieve, OperatorEnum.IN),)
_filter &= Filter("email", user_to_retrieve, OperatorEnum.IN)
users = self._service_provider.work_management.list_users(
_filter,
parent_entity=CustomFieldEntityEnum.TEAM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from lib.core.jsx_conditions import Filter
from lib.core.jsx_conditions import OperatorEnum
from lib.core.jsx_conditions import Query
from lib.core.service_types import AnnotationClassResponse
from lib.core.service_types import FolderListResponse
from lib.core.service_types import ListCategoryResponse
from lib.core.service_types import ListProjectCategoryResponse
Expand Down Expand Up @@ -547,16 +546,16 @@ def update_annotation_class(
project_id: int,
class_id: int,
data: WMAnnotationClassEntity,
) -> WMClassesResponse:
) -> ServiceResponse:
return self.client.request(
url=self.URL_UPDATE_ANNOTATION_CLASS.format(class_id=class_id),
method="patch",
data=data,
data=data.dict(exclude_unset=True, by_alias=True),
headers={
"x-sa-entity-context": self._generate_context(
team_id=self.client.team_id, project_id=project_id
),
},
content_type=AnnotationClassResponse,
content_type=WMClassesResponse,
dispatcher="data",
)
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,32 @@ def test_update_annotation_class_attribute_groups(self):

# Verify updates
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update")
self.assertDictEqual(classes[0], update_response)
# Verify response matches the current class state
self.assertEqual(update_response["id"], classes[0]["id"])
self.assertEqual(update_response["name"], classes[0]["name"])
self.assertEqual(update_response["color"], classes[0]["color"])
self.assertEqual(update_response["type"], classes[0]["type"])
self.assertEqual(
len(update_response["attribute_groups"]),
len(classes[0]["attribute_groups"]),
)

# Verify each attribute group matches
for resp_group, class_group in zip(
update_response["attribute_groups"], classes[0]["attribute_groups"]
):
self.assertEqual(resp_group["name"], class_group["name"])
self.assertEqual(resp_group["group_type"], class_group["group_type"])
self.assertEqual(resp_group["isRequired"], class_group["isRequired"])
self.assertEqual(
len(resp_group["attributes"]), len(class_group["attributes"])
)

# Verify each attribute matches
for resp_attr, class_attr in zip(
resp_group["attributes"], class_group["attributes"]
):
self.assertEqual(resp_attr["name"], class_attr["name"])
self.assertEqual(len(classes), 1)
updated_class = classes[0]

Expand All @@ -347,13 +372,188 @@ def test_update_annotation_class_attribute_groups(self):
)
self.assertEqual(color_group["group_type"], "checklist")
self.assertEqual(len(color_group["attributes"]), 3)
# check noting updated
response = sa.update_annotation_class(

def test_update_annotation_class_rename_attributes(self):
# Create annotation class
sa.create_annotation_class(
self.PROJECT_NAME,
"test_update",
"test_update_rename",
"#00FF00",
attribute_groups=[
{
"name": "Quality",
"group_type": "radio",
"attributes": [{"name": "Good"}, {"name": "Bad"}],
}
],
)

# Retrieve and rename attribute
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_rename")
updated_groups = classes[0]["attribute_groups"]
updated_groups[0]["attributes"][0]["name"] = "Excellent"
updated_groups[0]["name"] = "Rating"

# Update
sa.update_annotation_class(
self.PROJECT_NAME, "test_update_rename", attribute_groups=updated_groups
)

# Verify
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_rename")
rating_group = classes[0]["attribute_groups"][0]
self.assertEqual(rating_group["name"], "Rating")
self.assertEqual(rating_group["attributes"][0]["name"], "Excellent")

def test_update_annotation_class_delete_attributes(self):
# Create annotation class with multiple attributes
sa.create_annotation_class(
self.PROJECT_NAME,
"test_update_delete",
"#0000FF",
attribute_groups=[
{
"name": "Status",
"group_type": "checklist",
"attributes": [
{"name": "Active"},
{"name": "Inactive"},
{"name": "Pending"},
],
}
],
)

# Retrieve and remove one attribute
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_delete")
updated_groups = classes[0]["attribute_groups"]
updated_groups[0]["attributes"] = [
attr
for attr in updated_groups[0]["attributes"]
if attr["name"] != "Pending"
]

# Update
sa.update_annotation_class(
self.PROJECT_NAME, "test_update_delete", attribute_groups=updated_groups
)

# Verify
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_delete")
status_group = classes[0]["attribute_groups"][0]
self.assertEqual(len(status_group["attributes"]), 2)
attribute_names = [attr["name"] for attr in status_group["attributes"]]
self.assertNotIn("Pending", attribute_names)

def test_update_annotation_class_change_required_and_default(self):
# Create annotation class
sa.create_annotation_class(
self.PROJECT_NAME,
"test_update_required",
"#FFFF00",
attribute_groups=[
{
"name": "Priority",
"group_type": "radio",
"attributes": [
{"name": "Low"},
{"name": "Medium"},
{"name": "High"},
],
"default_value": "Low",
"isRequired": False,
}
],
)

# Retrieve and update required state and default value
classes = sa.search_annotation_classes(
self.PROJECT_NAME, "test_update_required"
)
updated_groups = classes[0]["attribute_groups"]
updated_groups[0]["isRequired"] = True
updated_groups[0]["attributes"][0]["default"] = 0
updated_groups[0]["attributes"][1]["default"] = 1

# Update
sa.update_annotation_class(
self.PROJECT_NAME, "test_update_required", attribute_groups=updated_groups
)

# Verify
classes = sa.search_annotation_classes(
self.PROJECT_NAME, "test_update_required"
)
priority_group = classes[0]["attribute_groups"][0]
self.assertTrue(priority_group["isRequired"])
self.assertEqual(priority_group["default_value"], "Medium")

def test_update_annotation_class_change_group_type(self):
# Create annotation class
sa.create_annotation_class(
self.PROJECT_NAME,
"test_update_type",
"#FF00FF",
attribute_groups=[
{
"name": "Options",
"group_type": "radio",
"attributes": [{"name": "Option1"}, {"name": "Option2"}],
}
],
)

# Retrieve and change group type
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_type")
updated_groups = classes[0]["attribute_groups"]
updated_groups[0]["group_type"] = "checklist"

# Update
sa.update_annotation_class(
self.PROJECT_NAME, "test_update_type", attribute_groups=updated_groups
)

# Verify
classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_update_type")
options_group = classes[0]["attribute_groups"][0]
self.assertEqual(options_group["group_type"], "checklist")

def test_update_annotation_class_no_changes(self):
# Create annotation class
sa.create_annotation_class(
self.PROJECT_NAME,
"test_update_nochange",
"#00FFFF",
attribute_groups=[
{
"name": "Category",
"group_type": "radio",
"attributes": [{"name": "A"}, {"name": "B"}],
}
],
)

# Retrieve class
classes = sa.search_annotation_classes(
self.PROJECT_NAME, "test_update_nochange"
)

# Update with same data
update_response = sa.update_annotation_class(
self.PROJECT_NAME,
"test_update_nochange",
attribute_groups=classes[0]["attribute_groups"],
)
self.assertDictEqual(response, classes[0])

# Verify response matches the current class state
self.assertEqual(update_response["id"], classes[0]["id"])
self.assertEqual(update_response["name"], classes[0]["name"])
self.assertEqual(update_response["color"], classes[0]["color"])
self.assertEqual(update_response["type"], classes[0]["type"])
self.assertEqual(
len(update_response["attribute_groups"]),
len(classes[0]["attribute_groups"]),
)


class TestVideoCreateAnnotationClasses(BaseTestCase):
Expand Down