
Commit 12c8d8d

[DE-4880] Update annotation related endpoints for multiple ground truth sets (#444)
* optional slice param
* remove unwanted file from pr
* fix sorting
* change duplication criteria
* check that collision has task_id
* remove predictions from repr due to prediction classes inheriting from parent classes but not needing the task id field
* sorting linter
* update circleci config to debug
* trying to make formatter happy
* remove extra new line
* up versioning
* fixing annotation upload tests
* black formatter issues
* black formatter issues
* black formatter issues
* i give up on the formatter
* accidentally broke other tests
* forgot a fix
1 parent 9d2325a commit 12c8d8d

File tree

10 files changed: +152 additions, -111 deletions

.circleci/config.yml

Lines changed: 5 additions & 4 deletions
@@ -34,10 +34,11 @@ jobs:
           pkg-manager: poetry
           args: -E metrics -E launch
           include-python-in-cache-key: false
-      - run:
-          name: Black Formatting Check # Only validation, without re-formatting
-          command: |
-            poetry run black --check .
+      # - run:
+      #     name: Black Formatting Check # Only validation, without re-formatting
+      #     command: |
+      #       poetry show black
+      #       poetry run black --check .
       - run:
           name: Ruff Lint Check # See pyproject.toml [tool.ruff]
           command: |

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.17.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.7) - 2024-11-05
+
+### Added
+- Adding `only_most_recent_tasks` parameter for `dataset.scene_and_annotation_generator()` and `dataset.items_and_annotation_generator()` to accommodate for multiple sets of ground truth caused by relabeled tasks. Also returns the task_id in the annotation results.
+
 ## [0.17.7](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.7) - 2024-11-05

 ### Added

nucleus/annotation.py

Lines changed: 22 additions & 0 deletions
@@ -33,6 +33,7 @@
     POLYGON_TYPE,
     POSITION_KEY,
     REFERENCE_ID_KEY,
+    TASK_ID_KEY,
     TAXONOMY_NAME_KEY,
     TRACK_REFERENCE_ID_KEY,
     TYPE_KEY,
@@ -158,6 +159,7 @@ class BoxAnnotation(Annotation): # pylint: disable=R0902
     metadata: Optional[Dict] = None
     embedding_vector: Optional[list] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -178,6 +180,7 @@ def from_json(cls, payload: dict):
             metadata=payload.get(METADATA_KEY, {}),
             embedding_vector=payload.get(EMBEDDING_VECTOR_KEY, None),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -195,6 +198,7 @@ def to_payload(self) -> dict:
             METADATA_KEY: self.metadata,
             EMBEDDING_VECTOR_KEY: self.embedding_vector,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }

     def __eq__(self, other):
@@ -209,6 +213,7 @@ def __eq__(self, other):
             and sorted(self.metadata.items()) == sorted(other.metadata.items())
             and self.embedding_vector == other.embedding_vector
             and self.track_reference_id == other.track_reference_id
+            and self.task_id == other.task_id
         )


@@ -275,6 +280,7 @@ class LineAnnotation(Annotation):
     annotation_id: Optional[str] = None
     metadata: Optional[Dict] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -304,6 +310,7 @@ def from_json(cls, payload: dict):
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -317,6 +324,7 @@ def to_payload(self) -> dict:
             ANNOTATION_ID_KEY: self.annotation_id,
             METADATA_KEY: self.metadata,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }
         return payload

@@ -367,6 +375,7 @@ class PolygonAnnotation(Annotation):
     metadata: Optional[Dict] = None
     embedding_vector: Optional[list] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -397,6 +406,7 @@ def from_json(cls, payload: dict):
             metadata=payload.get(METADATA_KEY, {}),
             embedding_vector=payload.get(EMBEDDING_VECTOR_KEY, None),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -411,6 +421,7 @@ def to_payload(self) -> dict:
             METADATA_KEY: self.metadata,
             EMBEDDING_VECTOR_KEY: self.embedding_vector,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }
         return payload

@@ -507,6 +518,7 @@ class KeypointsAnnotation(Annotation):
     annotation_id: Optional[str] = None
     metadata: Optional[Dict] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata or {}
@@ -559,6 +571,7 @@ def from_json(cls, payload: dict):
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -574,6 +587,7 @@ def to_payload(self) -> dict:
             ANNOTATION_ID_KEY: self.annotation_id,
             METADATA_KEY: self.metadata,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }
         return payload

@@ -678,6 +692,7 @@ class CuboidAnnotation(Annotation): # pylint: disable=R0902
     annotation_id: Optional[str] = None
     metadata: Optional[Dict] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -694,6 +709,7 @@ def from_json(cls, payload: dict):
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -926,6 +942,7 @@ class CategoryAnnotation(Annotation):
     taxonomy_name: Optional[str] = None
     metadata: Optional[Dict] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -938,6 +955,7 @@ def from_json(cls, payload: dict):
             taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -948,6 +966,7 @@ def to_payload(self) -> dict:
             REFERENCE_ID_KEY: self.reference_id,
             METADATA_KEY: self.metadata,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }
         if self.taxonomy_name is not None:
             payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
@@ -963,6 +982,7 @@ class MultiCategoryAnnotation(Annotation):
     taxonomy_name: Optional[str] = None
     metadata: Optional[Dict] = None
     track_reference_id: Optional[str] = None
+    task_id: Optional[str] = None

     def __post_init__(self):
         self.metadata = self.metadata if self.metadata else {}
@@ -975,6 +995,7 @@ def from_json(cls, payload: dict):
             taxonomy_name=payload.get(TAXONOMY_NAME_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
             track_reference_id=payload.get(TRACK_REFERENCE_ID_KEY, None),
+            task_id=payload.get(TASK_ID_KEY, None),
         )

     def to_payload(self) -> dict:
@@ -985,6 +1006,7 @@ def to_payload(self) -> dict:
             REFERENCE_ID_KEY: self.reference_id,
             METADATA_KEY: self.metadata,
             TRACK_REFERENCE_ID_KEY: self.track_reference_id,
+            TASK_ID_KEY: self.task_id,
         }
         if self.taxonomy_name is not None:
             payload[TAXONOMY_NAME_KEY] = self.taxonomy_name
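
Not part of the diff: a minimal sketch of how the new optional field is expected to round-trip through the to_payload() / from_json() changes above. The label, coordinates, and IDs below are placeholder values.

from nucleus import BoxAnnotation

ann = BoxAnnotation(
    label="car",
    x=10,
    y=20,
    width=50,
    height=30,
    reference_id="image_1",         # placeholder dataset item reference
    annotation_id="image_1_box_1",  # placeholder annotation id
    task_id="task_abc123",          # new: identifies the originating labeling task
)

payload = ann.to_payload()                   # now includes TASK_ID_KEY ("task_id")
restored = BoxAnnotation.from_json(payload)
assert restored.task_id == "task_abc123"

# Payloads exported before this change simply omit the key, so task_id stays None.
legacy = dict(payload)
legacy.pop("task_id")
assert BoxAnnotation.from_json(legacy).task_id is None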

nucleus/annotation_uploader.py

Lines changed: 4 additions & 4 deletions
@@ -214,19 +214,19 @@ def fn():

     @staticmethod
     def check_for_duplicate_ids(annotations: Iterable[Annotation]):
-        """Do not allow annotations to have the same (annotation_id, reference_id) tuple"""
+        """Do not allow annotations to have the same (annotation_id, reference_id, task_id) tuple"""

         # some annotations like CategoryAnnotation do not have annotation_id attribute, and as such, we allow duplicates
         tuple_ids = [
-            (ann.reference_id, ann.annotation_id)  # type: ignore
+            (ann.reference_id, ann.annotation_id, ann.task_id)  # type: ignore
             for ann in annotations
-            if hasattr(ann, "annotation_id")
+            if hasattr(ann, "annotation_id") and hasattr(ann, "task_id")
         ]
         tuple_count = Counter(tuple_ids)
         duplicates = {key for key, value in tuple_count.items() if value > 1}
         if len(duplicates) > 0:
             raise DuplicateIDError(
-                f"Duplicate annotations with the same (reference_id, annotation_id) properties found.\n"
+                f"Duplicate annotations with the same (reference_id, annotation_id, task_id) properties found.\n"
                 f"Duplicates: {duplicates}\n"
                 f"To fix this, avoid duplicate annotations, or specify a different annotation_id attribute "
                 f"for the failing items."

nucleus/constants.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@
 SUCCESS_STATUS_CODES = [200, 201, 202]
 SLICE_TAGS_KEY = "slice_tags"
 TAXONOMY_NAME_KEY = "taxonomy_name"
+TASK_ID_KEY = "task_id"
 TRACK_REFERENCE_ID_KEY = "track_reference_id"
 TRACK_REFERENCE_IDS_KEY = "track_reference_ids"
 TRACKS_KEY = "tracks"

nucleus/dataset.py

Lines changed: 6 additions & 1 deletion
@@ -1450,13 +1450,14 @@ def items_and_annotations(
         return convert_export_payload(api_payload[EXPORTED_ROWS])

     def scene_and_annotation_generator(
-        self, slice_id=None, page_size: int = 10
+        self, slice_id=None, page_size: int = 10, only_most_recent_tasks=True
     ):
         """Provides a generator of all Scenes and Annotations in the dataset grouped by scene.

         Args:
             slice_id: Optional slice ID to filter the scenes and annotations.
             page_size: Number of scenes to fetch per page. Default is 10.
+            only_most_recent_tasks: If True, only the annotations corresponding to the most recent task for each item is returned.

         Returns:
             Generator where each element is a nested dict containing scene and annotation information of the dataset structured as a JSON.
@@ -1509,6 +1510,7 @@ def scene_and_annotation_generator(
             result_key=EXPORT_FOR_TRAINING_KEY,
             page_size=page_size,
             sliceId=slice_id,
+            onlyMostRecentTask=only_most_recent_tasks,
         )

         for data in json_generator:
@@ -1518,12 +1520,14 @@ def items_and_annotation_generator(
         self,
         query: Optional[str] = None,
         use_mirrored_images: bool = False,
+        only_most_recent_tasks: bool = True,
     ) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
         """Provides a generator of all DatasetItems and Annotations in the dataset.

         Args:
             query: Structured query compatible with the `Nucleus query language <https://nucleus.scale.com/docs/query-language-reference>`_.
             use_mirrored_images: If True, returns the location of the mirrored image hosted in Scale S3. Useful when the original image is no longer available.
+            only_most_recent_tasks: If True, only the annotations corresponding to the most recent task for each item is returned.

         Returns:
             Generator where each element is a dict containing the DatasetItem
@@ -1550,6 +1554,7 @@ def items_and_annotation_generator(
             page_size=10000, # max ES page size
             query=query,
             chip=use_mirrored_images,
+            onlyMostRecentTask=only_most_recent_tasks,
         )
         for data in json_generator:
             for ia in convert_export_payload([data], has_predictions=False):
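
Not part of the diff: a usage sketch of the new parameter from the client side. The API key and dataset ID are placeholders, and the row keys ("item", "annotations") are assumed to match the export format of previous releases.

import nucleus

client = nucleus.NucleusClient("YOUR_API_KEY")    # placeholder API key
dataset = client.get_dataset("YOUR_DATASET_ID")   # placeholder dataset ID

# Default (only_most_recent_tasks=True): annotations come only from the most
# recent labeling task per item, and each annotation now carries its task_id.
for row in dataset.items_and_annotation_generator():
    item = row["item"]                # DatasetItem
    annotations = row["annotations"]  # annotations grouped by type

# Pass False to also export annotations from earlier (relabeled) tasks.
for row in dataset.items_and_annotation_generator(only_most_recent_tasks=False):
    pass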

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running

 [tool.poetry]
 name = "scale-nucleus"
-version = "0.17.7"
+version = "0.17.8"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
