Commit eb389ce

[DE-5241] Add class label export (#446)
* export class labels and test
* edit change log
* type issues
* Revert "type issues". This reverts commit e1a33db.
* type issues
* more type issues
* lower timer and add comments
1 parent 3181b66 commit eb389ce

File tree

6 files changed: +102 -34 lines changed


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.17.9](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.9) - 2025-03-11
+
+### Added
+- Adding `export_class_labels` methods to datasets and slices to extract unique class labels of the annotations in the dataset/slice.
+
 ## [0.17.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.8) - 2025-01-02
 
 ### Added
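
For orientation, here is a minimal usage sketch of the new export added in this commit (the API key, dataset ID, and slice ID below are placeholders, not values from this commit):

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
dataset = client.get_dataset("ds_placeholder_id")  # placeholder dataset ID

# Unique class labels across all annotations in the dataset.
print(dataset.export_class_labels())

# The same export restricted to a slice, via either entry point.
print(dataset.export_class_labels(slice_id="slc_placeholder_id"))
print(client.get_slice("slc_placeholder_id").export_class_labels())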

nucleus/dataset.py

Lines changed: 26 additions & 9 deletions
@@ -332,8 +332,7 @@ def items(self) -> List[DatasetItem]:
         dataset_item_jsons = response.get(DATASET_ITEMS_KEY, None)
 
         return [
-            DatasetItem.from_json(item_json)
-            for item_json in dataset_item_jsons
+            DatasetItem.from_json(item_json) for item_json in dataset_item_jsons
         ]
 
     @property
@@ -699,9 +698,7 @@ def append(
             asynchronous
         ), "In order to avoid timeouts, you must set asynchronous=True when uploading videos."
 
-            return self._append_video_scenes(
-                video_scenes, update, asynchronous
-            )
+            return self._append_video_scenes(video_scenes, update, asynchronous)
 
         if len(dataset_items) > WARN_FOR_LARGE_UPLOAD and not asynchronous:
             print(
@@ -2361,10 +2358,7 @@ def add_items_from_dir(
         )
 
         if len(items) > 0:
-            if (
-                len(items) > GLOB_SIZE_THRESHOLD_CHECK
-                and not skip_size_warning
-            ):
+            if len(items) > GLOB_SIZE_THRESHOLD_CHECK and not skip_size_warning:
                 raise Exception(
                     f"Found over {GLOB_SIZE_THRESHOLD_CHECK} items in {dirname}. If this is intended,"
                     f" set skip_size_warning=True when calling this function."
@@ -2411,3 +2405,26 @@ def upload_lidar_semseg_predictions(
             route=f"dataset/{self.id}/model/{model.id}/pointcloud/{pointcloud_ref_id}/uploadLSSPrediction",
             requests_command=requests.post,
         )
+
+    def export_class_labels(self, slice_id: Optional[str] = None):
+        """Fetches a list of class labels for the dataset.
+
+        Args:
+            slice_id (str | None): The ID of the slice to export class labels for. If None, export class labels for the entire dataset.
+
+        Returns:
+            A list of class labels for the dataset.
+        """
+        if slice_id:
+            api_payload = self._client.make_request(
+                payload=None,
+                route=f"slice/{slice_id}/class_labels",
+                requests_command=requests.get,
+            )
+        else:
+            api_payload = self._client.make_request(
+                payload=None,
+                route=f"dataset/{self.id}/class_labels",
+                requests_command=requests.get,
+            )
+        return api_payload.get("data", [])
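
A design note on the method above: passing slice_id reroutes the request to the slice endpoint, so Dataset.export_class_labels(slice_id=...) and Slice.export_class_labels() return the same payload. One plausible use of the returned list, sketched here under the assumption of a hypothetical annotations_to_upload collection:

# Hypothetical guard: flag labels that the dataset has not seen yet.
known_labels = set(dataset.export_class_labels())
new_labels = {ann.label for ann in annotations_to_upload}  # annotations_to_upload is illustrative
unexpected = new_labels - known_labels
if unexpected:
    print(f"Labels not yet present in the dataset: {sorted(unexpected)}")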

nucleus/slice.py

Lines changed: 10 additions & 3 deletions
@@ -168,9 +168,7 @@ def created_at(self) -> Optional[datetime.datetime]:
     @property
     def pending_job_count(self) -> Optional[int]:
         if self._pending_job_count is None:
-            self._pending_job_count = self.info().get(
-                "pending_job_count", None
-            )
+            self._pending_job_count = self.info().get("pending_job_count", None)
         return self._pending_job_count
 
     @classmethod
@@ -705,6 +703,15 @@ def export_raw_items(self) -> List[Dict[str, str]]:
         )
         return api_payload
 
+    def export_class_labels(self):
+        """Fetches a list of class labels for the slice."""
+        api_payload = self._client.make_request(
+            payload=None,
+            route=f"slice/{self.id}/class_labels",
+            requests_command=requests.get,
+        )
+        return api_payload.get("data", [])
+
 
 def check_annotations_are_in_slice(
     annotations: List[Annotation], slice_to_check: Slice
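
Both implementations read the label list from a "data" key, which suggests a response body shaped roughly like the comment below (an inference from api_payload.get("data", []), not from documented API output):

labels = slc.export_class_labels()
# Presumed wire format: {"data": ["car", "pedestrian", "traffic_light"]}
# so `labels` would be ["car", "pedestrian", "traffic_light"], or [] if the key is absent.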

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.17.8"
+version = "0.17.9"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

tests/test_dataset.py

Lines changed: 22 additions & 10 deletions
@@ -309,9 +309,9 @@ def test_dataset_append_async(dataset: Dataset):
 
 def test_dataset_append_async_with_local_path(dataset: Dataset):
     ds_items = make_dataset_items()
-    ds_items[
-        0
-    ].image_location = "/a/fake/local/path/you/can/tell/is/local/but/is/fake"
+    ds_items[0].image_location = (
+        "/a/fake/local/path/you/can/tell/is/local/but/is/fake"
+    )
     with pytest.raises(ValueError):
         dataset.append(ds_items, asynchronous=True)
 
@@ -354,8 +354,7 @@ def test_raises_error_for_duplicate():
         ]
     )
     assert (
-        str(error.value)
-        == "Duplicate reference IDs found among dataset_items:"
+        str(error.value) == "Duplicate reference IDs found among dataset_items:"
         " {'duplicate': 'Count: 2'}"
     )
 
@@ -480,9 +479,7 @@ def sort_labelmap(segmentation_annotation):
         exported[0][ANNOTATIONS_KEY][SEGMENTATION_TYPE]
     ) == sort_labelmap(clear_fields(segmentation_annotation))
     assert exported[0][ANNOTATIONS_KEY][POLYGON_TYPE][0] == polygon_annotation
-    assert (
-        exported[0][ANNOTATIONS_KEY][CATEGORY_TYPE][0] == category_annotation
-    )
+    assert exported[0][ANNOTATIONS_KEY][CATEGORY_TYPE][0] == category_annotation
     exported[0][ANNOTATIONS_KEY][MULTICATEGORY_TYPE][0].labels = set(
         exported[0][ANNOTATIONS_KEY][MULTICATEGORY_TYPE][0].labels
     )
@@ -535,8 +532,7 @@ def test_dataset_item_iterator(dataset):
     dataset.append(items)
     expected_items = {item.reference_id: item for item in dataset.items}
     actual_items = {
-        item.reference_id: item
-        for item in dataset.items_generator(page_size=1)
+        item.reference_id: item for item in dataset.items_generator(page_size=1)
     }
     for key in expected_items:
         assert actual_items[key] == expected_items[key]
@@ -610,3 +606,19 @@ def test_create_update_dataset_from_dir(CLIENT):
         assert dataset_item.reference_id in reference_ids
         reference_ids.remove(dataset_item.reference_id)
     CLIENT.delete_dataset(dataset.id)
+
+
+@pytest.mark.integration
+def test_dataset_export_class_labels(dataset):
+    dataset.append(make_dataset_items())
+    # Create a box annotation from the test data
+    box_annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+    dataset.annotate(annotations=[box_annotation])
+
+    # Wait for annotations to be uploaded (takes a while)
+    import time
+
+    time.sleep(40)
+    class_labels = dataset.export_class_labels()
+    # Compare against just the label from the test annotation
+    assert class_labels == [box_annotation.label]
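
The fixed time.sleep(40) in the test above couples the test to backend indexing latency. A hedged alternative would poll the new method until labels appear; this is a sketch with arbitrary timeout values, not what the commit ships:

import time

def wait_for_class_labels(dataset, timeout_s=60, interval_s=5):
    # Poll export_class_labels until it returns a non-empty list or the deadline passes.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        labels = dataset.export_class_labels()
        if labels:
            return labels
        time.sleep(interval_s)
    raise TimeoutError("class labels were not indexed within the timeout")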

tests/test_slice.py

Lines changed: 38 additions & 11 deletions
@@ -70,8 +70,7 @@ def test_slice_create_and_export(dataset):
     slice_ref_ids = [item.reference_id for item in ds_items[:1]]
     # This test assumes one box annotation per item.
     annotations = [
-        BoxAnnotation.from_json(json_data)
-        for json_data in TEST_BOX_ANNOTATIONS
+        BoxAnnotation.from_json(json_data) for json_data in TEST_BOX_ANNOTATIONS
     ]
     # Slice creation
     slc = dataset.create_slice(
@@ -96,17 +95,45 @@ def get_expected_item(reference_id):
     for row in slc.items_and_annotations():
         reference_id = row[ITEM_KEY].reference_id
         assert row[ITEM_KEY] == get_expected_item(reference_id)
-        assert row[ANNOTATIONS_KEY][BOX_TYPE][
-            0
-        ] == get_expected_box_annotation(reference_id)
+        assert row[ANNOTATIONS_KEY][BOX_TYPE][0] == get_expected_box_annotation(
+            reference_id
+        )
 
     # test async
     for row in slc.items_and_annotation_generator():
         reference_id = row[ITEM_KEY].reference_id
         assert row[ITEM_KEY] == get_expected_item(reference_id)
-        assert row[ANNOTATIONS_KEY][BOX_TYPE][
-            0
-        ] == get_expected_box_annotation(reference_id)
+        assert row[ANNOTATIONS_KEY][BOX_TYPE][0] == get_expected_box_annotation(
+            reference_id
+        )
+
+
+@pytest.mark.integration
+def test_slice_export_class_labels(dataset):
+    ds_items = dataset.items
+
+    slice_ref_ids = [item.reference_id for item in ds_items]
+    # This test assumes one box annotation per item.
+    annotations = [
+        BoxAnnotation.from_json(json_data)
+        for json_data in TEST_BOX_ANNOTATIONS[:1]
+    ]
+    # Slice creation
+    slc = dataset.create_slice(
+        name=TEST_SLICE_NAME,
+        reference_ids=slice_ref_ids,
+    )
+
+    dataset.annotate(annotations=annotations)
+
+    # Wait for annotations to be uploaded (takes a while)
+    import time
+
+    time.sleep(40)
+    class_labels = slc.export_class_labels()
+
+    expected_class_labels = [anno.label for anno in annotations]
+    assert class_labels == expected_class_labels
 
 
 # TODO(drake): investigate why this only flakes in circleci
@@ -140,9 +167,9 @@ def get_expected_item(reference_id):
     for row in exported:
         reference_id = row[ITEM_KEY].reference_id
         assert row[ITEM_KEY] == get_expected_item(reference_id)
-        assert row[PREDICTIONS_KEY][BOX_TYPE][
-            0
-        ] == get_expected_box_prediction(reference_id)
+        assert row[PREDICTIONS_KEY][BOX_TYPE][0] == get_expected_box_prediction(
+            reference_id
+        )
 
 
 def test_slice_append(dataset):
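
Both new tests compare the exported labels as ordered lists. If the endpoint does not guarantee ordering, a set comparison would be the more robust assertion (a suggested variant, not part of this commit):

assert set(class_labels) == {anno.label for anno in annotations}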

0 commit comments