Skip to content

Commit 3181b66

Browse files
authored
Task-based annotation fixes (#445)
* optional slice param * fix jean comments * revert constants * fix pylint * fix import issues * fixing dup check * fix bug * more bugs * fix another bug * another fix
1 parent 12c8d8d commit 3181b66

File tree

5 files changed

+87
-59
lines changed

5 files changed

+87
-59
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [0.17.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.7) - 2024-11-05
8+
## [0.17.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.8) - 2025-01-02
99

1010
### Added
1111
- Adding `only_most_recent_tasks` parameter for `dataset.scene_and_annotation_generator()` and `dataset.items_and_annotation_generator()` to accommodate for multiple sets of ground truth caused by relabeled tasks. Also returns the task_id in the annotation results.

nucleus/annotation.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class BoxAnnotation(Annotation): # pylint: disable=R0902
159159
metadata: Optional[Dict] = None
160160
embedding_vector: Optional[list] = None
161161
track_reference_id: Optional[str] = None
162-
task_id: Optional[str] = None
162+
_task_id: Optional[str] = field(default=None, repr=False)
163163

164164
def __post_init__(self):
165165
self.metadata = self.metadata if self.metadata else {}
@@ -180,7 +180,7 @@ def from_json(cls, payload: dict):
180180
metadata=payload.get(METADATA_KEY, {}),
181181
embedding_vector=payload.get(EMBEDDING_VECTOR_KEY, None),
182182
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
183-
task_id=payload.get(TASK_ID_KEY, None),
183+
_task_id=payload.get(TASK_ID_KEY, None),
184184
)
185185

186186
def to_payload(self) -> dict:
@@ -198,7 +198,7 @@ def to_payload(self) -> dict:
198198
METADATA_KEY: self.metadata,
199199
EMBEDDING_VECTOR_KEY: self.embedding_vector,
200200
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
201-
TASK_ID_KEY: self.task_id,
201+
TASK_ID_KEY: self._task_id,
202202
}
203203

204204
def __eq__(self, other):
@@ -213,7 +213,7 @@ def __eq__(self, other):
213213
and sorted(self.metadata.items()) == sorted(other.metadata.items())
214214
and self.embedding_vector == other.embedding_vector
215215
and self.track_reference_id == other.track_reference_id
216-
and self.task_id == other.task_id
216+
and self._task_id == other._task_id
217217
)
218218

219219

@@ -280,7 +280,7 @@ class LineAnnotation(Annotation):
280280
annotation_id: Optional[str] = None
281281
metadata: Optional[Dict] = None
282282
track_reference_id: Optional[str] = None
283-
task_id: Optional[str] = None
283+
_task_id: Optional[str] = field(default=None, repr=False)
284284

285285
def __post_init__(self):
286286
self.metadata = self.metadata if self.metadata else {}
@@ -310,7 +310,7 @@ def from_json(cls, payload: dict):
310310
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
311311
metadata=payload.get(METADATA_KEY, {}),
312312
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
313-
task_id=payload.get(TASK_ID_KEY, None),
313+
_task_id=payload.get(TASK_ID_KEY, None),
314314
)
315315

316316
def to_payload(self) -> dict:
@@ -324,7 +324,7 @@ def to_payload(self) -> dict:
324324
ANNOTATION_ID_KEY: self.annotation_id,
325325
METADATA_KEY: self.metadata,
326326
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
327-
TASK_ID_KEY: self.task_id,
327+
TASK_ID_KEY: self._task_id,
328328
}
329329
return payload
330330

@@ -375,7 +375,7 @@ class PolygonAnnotation(Annotation):
375375
metadata: Optional[Dict] = None
376376
embedding_vector: Optional[list] = None
377377
track_reference_id: Optional[str] = None
378-
task_id: Optional[str] = None
378+
_task_id: Optional[str] = field(default=None, repr=False)
379379

380380
def __post_init__(self):
381381
self.metadata = self.metadata if self.metadata else {}
@@ -406,7 +406,7 @@ def from_json(cls, payload: dict):
406406
metadata=payload.get(METADATA_KEY, {}),
407407
embedding_vector=payload.get(EMBEDDING_VECTOR_KEY, None),
408408
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
409-
task_id=payload.get(TASK_ID_KEY, None),
409+
_task_id=payload.get(TASK_ID_KEY, None),
410410
)
411411

412412
def to_payload(self) -> dict:
@@ -421,7 +421,7 @@ def to_payload(self) -> dict:
421421
METADATA_KEY: self.metadata,
422422
EMBEDDING_VECTOR_KEY: self.embedding_vector,
423423
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
424-
TASK_ID_KEY: self.task_id,
424+
TASK_ID_KEY: self._task_id,
425425
}
426426
return payload
427427

@@ -518,7 +518,7 @@ class KeypointsAnnotation(Annotation):
518518
annotation_id: Optional[str] = None
519519
metadata: Optional[Dict] = None
520520
track_reference_id: Optional[str] = None
521-
task_id: Optional[str] = None
521+
_task_id: Optional[str] = field(default=None, repr=False)
522522

523523
def __post_init__(self):
524524
self.metadata = self.metadata or {}
@@ -571,7 +571,7 @@ def from_json(cls, payload: dict):
571571
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
572572
metadata=payload.get(METADATA_KEY, {}),
573573
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
574-
task_id=payload.get(TASK_ID_KEY, None),
574+
_task_id=payload.get(TASK_ID_KEY, None),
575575
)
576576

577577
def to_payload(self) -> dict:
@@ -587,7 +587,7 @@ def to_payload(self) -> dict:
587587
ANNOTATION_ID_KEY: self.annotation_id,
588588
METADATA_KEY: self.metadata,
589589
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
590-
TASK_ID_KEY: self.task_id,
590+
TASK_ID_KEY: self._task_id,
591591
}
592592
return payload
593593

@@ -692,7 +692,7 @@ class CuboidAnnotation(Annotation): # pylint: disable=R0902
692692
annotation_id: Optional[str] = None
693693
metadata: Optional[Dict] = None
694694
track_reference_id: Optional[str] = None
695-
task_id: Optional[str] = None
695+
_task_id: Optional[str] = field(default=None, repr=False)
696696

697697
def __post_init__(self):
698698
self.metadata = self.metadata if self.metadata else {}
@@ -709,7 +709,7 @@ def from_json(cls, payload: dict):
709709
annotation_id=payload.get(ANNOTATION_ID_KEY, None),
710710
metadata=payload.get(METADATA_KEY, {}),
711711
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
712-
task_id=payload.get(TASK_ID_KEY, None),
712+
_task_id=payload.get(TASK_ID_KEY, None),
713713
)
714714

715715
def to_payload(self) -> dict:
@@ -729,7 +729,8 @@ def to_payload(self) -> dict:
729729
payload[METADATA_KEY] = self.metadata
730730
if self.track_reference_id:
731731
payload[TRACK_REFERENCE_ID_KEY] = self.track_reference_id
732-
732+
if self._task_id:
733+
payload[TASK_ID_KEY] = self._task_id
733734
return payload
734735

735736

@@ -942,7 +943,7 @@ class CategoryAnnotation(Annotation):
942943
taxonomy_name: Optional[str] = None
943944
metadata: Optional[Dict] = None
944945
track_reference_id: Optional[str] = None
945-
task_id: Optional[str] = None
946+
_task_id: Optional[str] = field(default=None, repr=False)
946947

947948
def __post_init__(self):
948949
self.metadata = self.metadata if self.metadata else {}
@@ -955,7 +956,7 @@ def from_json(cls, payload: dict):
955956
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
956957
metadata=payload.get(METADATA_KEY, {}),
957958
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
958-
task_id=payload.get(TASK_ID_KEY, None),
959+
_task_id=payload.get(TASK_ID_KEY, None),
959960
)
960961

961962
def to_payload(self) -> dict:
@@ -966,7 +967,7 @@ def to_payload(self) -> dict:
966967
REFERENCE_ID_KEY: self.reference_id,
967968
METADATA_KEY: self.metadata,
968969
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
969-
TASK_ID_KEY: self.task_id,
970+
TASK_ID_KEY: self._task_id,
970971
}
971972
if self.taxonomy_name is not None:
972973
payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
@@ -982,7 +983,7 @@ class MultiCategoryAnnotation(Annotation):
982983
taxonomy_name: Optional[str] = None
983984
metadata: Optional[Dict] = None
984985
track_reference_id: Optional[str] = None
985-
task_id: Optional[str] = None
986+
_task_id: Optional[str] = field(default=None, repr=False)
986987

987988
def __post_init__(self):
988989
self.metadata = self.metadata if self.metadata else {}
@@ -995,7 +996,7 @@ def from_json(cls, payload: dict):
995996
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
996997
metadata=payload.get(METADATA_KEY, {}),
997998
track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
998-
task_id=payload.get(TASK_ID_KEY, None),
999+
_task_id=payload.get(TASK_ID_KEY, None),
9991000
)
10001001

10011002
def to_payload(self) -> dict:
@@ -1006,7 +1007,7 @@ def to_payload(self) -> dict:
10061007
REFERENCE_ID_KEY: self.reference_id,
10071008
METADATA_KEY: self.metadata,
10081009
TRACK_REFERENCE_ID_KEY: self.track_reference_id,
1009-
TASK_ID_KEY: self.task_id,
1010+
TASK_ID_KEY: self._task_id,
10101011
}
10111012
if self.taxonomy_name is not None:
10121013
payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
@@ -1045,6 +1046,7 @@ class SceneCategoryAnnotation(Annotation):
10451046
reference_id: str
10461047
taxonomy_name: Optional[str] = None
10471048
metadata: Optional[Dict] = field(default_factory=dict)
1049+
_task_id: Optional[str] = field(default=None, repr=False)
10481050

10491051
@classmethod
10501052
def from_json(cls, payload: dict):
@@ -1053,6 +1055,7 @@ def from_json(cls, payload: dict):
10531055
reference_id=payload[REFERENCE_ID_KEY],
10541056
taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
10551057
metadata=payload.get(METADATA_KEY, {}),
1058+
_task_id=payload.get(TASK_ID_KEY, None),
10561059
)
10571060

10581061
def to_payload(self) -> dict:
@@ -1062,6 +1065,7 @@ def to_payload(self) -> dict:
10621065
GEOMETRY_KEY: {},
10631066
REFERENCE_ID_KEY: self.reference_id,
10641067
METADATA_KEY: self.metadata,
1068+
TASK_ID_KEY: self._task_id,
10651069
}
10661070
if self.taxonomy_name is not None:
10671071
payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
@@ -1079,9 +1083,7 @@ class AnnotationList:
10791083
default_factory=list
10801084
)
10811085
cuboid_annotations: List[CuboidAnnotation] = field(default_factory=list)
1082-
category_annotations: List[CategoryAnnotation] = field(
1083-
default_factory=list
1084-
)
1086+
category_annotations: List[CategoryAnnotation] = field(default_factory=list)
10851087
multi_category_annotations: List[MultiCategoryAnnotation] = field(
10861088
default_factory=list
10871089
)

nucleus/annotation_uploader.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ def get_form_data_and_file_pointers_fn(
176176
"""
177177

178178
def fn():
179-
request_json = construct_segmentation_payload(
180-
segmentations, update
181-
)
179+
request_json = construct_segmentation_payload(segmentations, update)
182180
form_data = [
183181
FileFormField(
184182
name=SERIALIZED_REQUEST_KEY,
@@ -212,15 +210,17 @@ def fn():
212210

213211
return fn
214212

215-
@staticmethod
216-
def check_for_duplicate_ids(annotations: Iterable[Annotation]):
213+
def check_for_duplicate_ids(self, annotations: Iterable[Annotation]):
217214
"""Do not allow annotations to have the same (annotation_id, reference_id, task_id) tuple"""
218215

219-
# some annotations like CategoryAnnotation do not have annotation_id attribute, and as such, we allow duplicates
220216
tuple_ids = [
221-
(ann.reference_id, ann.annotation_id, ann.task_id) # type: ignore
217+
(
218+
ann.reference_id,
219+
ann.annotation_id,
220+
getattr(ann, "_task_id", None),
221+
)
222222
for ann in annotations
223-
if hasattr(ann, "annotation_id") and hasattr(ann, "task_id")
223+
if hasattr(ann, "annotation_id")
224224
]
225225
tuple_count = Counter(tuple_ids)
226226
duplicates = {key for key, value in tuple_count.items() if value > 1}
@@ -255,3 +255,20 @@ def __init__(
255255
self._route = (
256256
f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
257257
)
258+
259+
def check_for_duplicate_ids(self, annotations: Iterable[Annotation]):
260+
"""Do not allow predictions to have the same (annotation_id, reference_id) tuple"""
261+
tuple_ids = [
262+
(pred.annotation_id, pred.reference_id) # type: ignore
263+
for pred in annotations
264+
if hasattr(pred, "annotation_id") and hasattr(pred, "reference_id")
265+
]
266+
tuple_count = Counter(tuple_ids)
267+
duplicates = {key for key, value in tuple_count.items() if value > 1}
268+
if len(duplicates) > 0:
269+
raise DuplicateIDError(
270+
f"Duplicate predictions with the same (annotation_id, reference_id) properties found.\n"
271+
f"Duplicates: {duplicates}\n"
272+
f"To fix this, avoid duplicate predictions, or specify a different annotation_id attribute "
273+
f"for the failing items."
274+
)

tests/test_annotation.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ def test_polygon_gt_upload(dataset):
193193
assert response["annotations_processed"] == 1
194194
assert response["annotations_ignored"] == 0
195195

196-
response = dataset.refloc(annotation.reference_id)["annotations"][
197-
"polygon"
198-
]
196+
response = dataset.refloc(annotation.reference_id)["annotations"]["polygon"]
199197
assert len(response) == 1
200198
response_annotation = response[0]
201199
assert_polygon_annotation_matches_dict(
@@ -370,7 +368,7 @@ def test_mixed_annotation_upload(dataset):
370368

371369

372370
def test_box_gt_upload_update(dataset):
373-
TEST_BOX_ANNOTATIONS[0]["task_id"] = "test_task_id"
371+
TEST_BOX_ANNOTATIONS[0]["_task_id"] = "test_task_id"
374372
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
375373
response = dataset.annotate(annotations=[annotation])
376374

@@ -384,7 +382,7 @@ def test_box_gt_upload_update(dataset):
384382
annotation_update_params["reference_id"] = TEST_BOX_ANNOTATIONS[0][
385383
"reference_id"
386384
]
387-
annotation_update_params["task_id"] = TEST_BOX_ANNOTATIONS[0]["task_id"]
385+
annotation_update_params["_task_id"] = TEST_BOX_ANNOTATIONS[0]["_task_id"]
388386

389387
annotation_update = BoxAnnotation(**annotation_update_params)
390388
response = dataset.annotate(annotations=[annotation_update], update=True)
@@ -401,7 +399,7 @@ def test_box_gt_upload_update(dataset):
401399

402400

403401
def test_box_gt_upload_ignore(dataset):
404-
TEST_BOX_ANNOTATIONS[0]["task_id"] = "test_task_id"
402+
TEST_BOX_ANNOTATIONS[0]["_task_id"] = "test_task_id"
405403
annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
406404

407405
print(annotation)
@@ -418,7 +416,7 @@ def test_box_gt_upload_ignore(dataset):
418416
annotation_update_params["reference_id"] = TEST_BOX_ANNOTATIONS[0][
419417
"reference_id"
420418
]
421-
annotation_update_params["task_id"] = TEST_BOX_ANNOTATIONS[0]["task_id"]
419+
annotation_update_params["_task_id"] = TEST_BOX_ANNOTATIONS[0]["_task_id"]
422420
annotation_update = BoxAnnotation(**annotation_update_params)
423421

424422
# Default behavior is ignore.
@@ -450,19 +448,15 @@ def test_polygon_gt_upload_update(dataset):
450448
annotation_update_params["reference_id"] = TEST_POLYGON_ANNOTATIONS[0][
451449
"reference_id"
452450
]
453-
annotation_update_params["task_id"] = TEST_POLYGON_ANNOTATIONS[0][
454-
"task_id"
455-
]
451+
annotation_update_params["task_id"] = TEST_POLYGON_ANNOTATIONS[0]["task_id"]
456452

457453
annotation_update = PolygonAnnotation.from_json(annotation_update_params)
458454
response = dataset.annotate(annotations=[annotation_update], update=True)
459455

460456
assert response["annotations_processed"] == 1
461457
assert response["annotations_ignored"] == 0
462458

463-
response = dataset.refloc(annotation.reference_id)["annotations"][
464-
"polygon"
465-
]
459+
response = dataset.refloc(annotation.reference_id)["annotations"]["polygon"]
466460
assert len(response) == 1
467461
response_annotation = response[0]
468462
assert_polygon_annotation_matches_dict(
@@ -485,9 +479,7 @@ def test_polygon_gt_upload_ignore(dataset):
485479
annotation_update_params["reference_id"] = TEST_POLYGON_ANNOTATIONS[0][
486480
"reference_id"
487481
]
488-
annotation_update_params["task_id"] = TEST_POLYGON_ANNOTATIONS[0][
489-
"task_id"
490-
]
482+
annotation_update_params["task_id"] = TEST_POLYGON_ANNOTATIONS[0]["task_id"]
491483

492484
annotation_update = PolygonAnnotation.from_json(annotation_update_params)
493485
# Default behavior is ignore.
@@ -496,9 +488,7 @@ def test_polygon_gt_upload_ignore(dataset):
496488
assert response["annotations_processed"] == 0
497489
assert response["annotations_ignored"] == 1
498490

499-
response = dataset.refloc(annotation.reference_id)["annotations"][
500-
"polygon"
501-
]
491+
response = dataset.refloc(annotation.reference_id)["annotations"]["polygon"]
502492
assert len(response) == 1
503493
response_annotation = response[0]
504494
assert_polygon_annotation_matches_dict(

0 commit comments

Comments
 (0)