[DE-5241] Add class label export #446


Merged · 7 commits · Mar 13, 2025
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/nucleus-python-client)
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+ ## [0.17.9](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.9) - 2025-03-11
+
+ ### Added
+ - Added `export_class_labels` methods to `Dataset` and `Slice` to extract the unique class labels of the annotations in the dataset or slice.

## [0.17.8](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.8) - 2025-01-02

### Added
35 changes: 26 additions & 9 deletions nucleus/dataset.py
@@ -332,8 +332,7 @@ def items(self) -> List[DatasetItem]:
dataset_item_jsons = response.get(DATASET_ITEMS_KEY, None)

return [
-     DatasetItem.from_json(item_json)
-     for item_json in dataset_item_jsons
+     DatasetItem.from_json(item_json) for item_json in dataset_item_jsons
]

@property
@@ -699,9 +698,7 @@ def append(
asynchronous
), "In order to avoid timeouts, you must set asynchronous=True when uploading videos."

- return self._append_video_scenes(
-     video_scenes, update, asynchronous
- )
+ return self._append_video_scenes(video_scenes, update, asynchronous)

if len(dataset_items) > WARN_FOR_LARGE_UPLOAD and not asynchronous:
print(
@@ -2361,10 +2358,7 @@ def add_items_from_dir(
)

if len(items) > 0:
- if (
-     len(items) > GLOB_SIZE_THRESHOLD_CHECK
-     and not skip_size_warning
- ):
+ if len(items) > GLOB_SIZE_THRESHOLD_CHECK and not skip_size_warning:
raise Exception(
f"Found over {GLOB_SIZE_THRESHOLD_CHECK} items in {dirname}. If this is intended,"
f" set skip_size_warning=True when calling this function."
@@ -2411,3 +2405,26 @@ def upload_lidar_semseg_predictions(
route=f"dataset/{self.id}/model/{model.id}/pointcloud/{pointcloud_ref_id}/uploadLSSPrediction",
requests_command=requests.post,
)

+ def export_class_labels(self, slice_id: Optional[str] = None):
+     """Fetches a list of class labels for the dataset.
+
+     Args:
+         slice_id (str | None): The ID of the slice to export class labels for. If None, export class labels for the entire dataset.
+
+     Returns:
+         A list of class labels for the dataset.
+     """
+     if slice_id:
+         api_payload = self._client.make_request(
+             payload=None,
+             route=f"slice/{slice_id}/class_labels",
+             requests_command=requests.get,
+         )
+     else:
+         api_payload = self._client.make_request(
+             payload=None,
+             route=f"dataset/{self.id}/class_labels",
+             requests_command=requests.get,
+         )
+     return api_payload.get("data", [])
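For context, a minimal usage sketch of the new `Dataset.export_class_labels` method. The API key, dataset ID, and slice ID below are placeholders; `NucleusClient` and `get_dataset` are the client's existing entry points:

```python
import nucleus

# Placeholder credentials and IDs, for illustration only.
client = nucleus.NucleusClient("YOUR_API_KEY")
dataset = client.get_dataset("ds_abc123")

# Unique class labels across all annotations in the dataset.
all_labels = dataset.export_class_labels()

# Scoped to a single slice by passing its ID.
slice_labels = dataset.export_class_labels(slice_id="slc_def456")

print(all_labels)    # e.g. ["car", "pedestrian"]
print(slice_labels)  # labels present in the slice only
```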
13 changes: 10 additions & 3 deletions nucleus/slice.py
@@ -168,9 +168,7 @@ def created_at(self) -> Optional[datetime.datetime]:
@property
def pending_job_count(self) -> Optional[int]:
if self._pending_job_count is None:
-     self._pending_job_count = self.info().get(
-         "pending_job_count", None
-     )
+     self._pending_job_count = self.info().get("pending_job_count", None)
return self._pending_job_count

@classmethod
@@ -705,6 +703,15 @@ def export_raw_items(self) -> List[Dict[str, str]]:
)
return api_payload

+ def export_class_labels(self):
+     """Fetches a list of class labels for the slice."""
+     api_payload = self._client.make_request(
+         payload=None,
+         route=f"slice/{self.id}/class_labels",
+         requests_command=requests.get,
+     )
+     return api_payload.get("data", [])


def check_annotations_are_in_slice(
annotations: List[Annotation], slice_to_check: Slice
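The slice-level method hits the same `slice/{id}/class_labels` route; a corresponding sketch, where the slice ID is a placeholder and `get_slice` is assumed from the existing client API:

```python
import nucleus

client = nucleus.NucleusClient("YOUR_API_KEY")
slc = client.get_slice("slc_def456")  # placeholder slice ID

# Unique class labels of the annotations in this slice.
labels = slc.export_class_labels()
print(labels)
```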
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running

[tool.poetry]
name = "scale-nucleus"
version = "0.17.8"
version = "0.17.9"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
32 changes: 22 additions & 10 deletions tests/test_dataset.py
@@ -309,9 +309,9 @@ def test_dataset_append_async(dataset: Dataset):

def test_dataset_append_async_with_local_path(dataset: Dataset):
ds_items = make_dataset_items()
- ds_items[
-     0
- ].image_location = "/a/fake/local/path/you/can/tell/is/local/but/is/fake"
+ ds_items[0].image_location = (
+     "/a/fake/local/path/you/can/tell/is/local/but/is/fake"
+ )
with pytest.raises(ValueError):
dataset.append(ds_items, asynchronous=True)

@@ -354,8 +354,7 @@ def test_raises_error_for_duplicate():
]
)
assert (
-     str(error.value)
-     == "Duplicate reference IDs found among dataset_items:"
+     str(error.value) == "Duplicate reference IDs found among dataset_items:"
" {'duplicate': 'Count: 2'}"
)

@@ -480,9 +479,7 @@ def sort_labelmap(segmentation_annotation):
exported[0][ANNOTATIONS_KEY][SEGMENTATION_TYPE]
) == sort_labelmap(clear_fields(segmentation_annotation))
assert exported[0][ANNOTATIONS_KEY][POLYGON_TYPE][0] == polygon_annotation
- assert (
-     exported[0][ANNOTATIONS_KEY][CATEGORY_TYPE][0] == category_annotation
- )
+ assert exported[0][ANNOTATIONS_KEY][CATEGORY_TYPE][0] == category_annotation
exported[0][ANNOTATIONS_KEY][MULTICATEGORY_TYPE][0].labels = set(
exported[0][ANNOTATIONS_KEY][MULTICATEGORY_TYPE][0].labels
)
@@ -535,8 +532,7 @@ def test_dataset_item_iterator(dataset):
dataset.append(items)
expected_items = {item.reference_id: item for item in dataset.items}
actual_items = {
-     item.reference_id: item
-     for item in dataset.items_generator(page_size=1)
+     item.reference_id: item for item in dataset.items_generator(page_size=1)
}
for key in expected_items:
assert actual_items[key] == expected_items[key]
@@ -610,3 +606,19 @@ def test_create_update_dataset_from_dir(CLIENT):
assert dataset_item.reference_id in reference_ids
reference_ids.remove(dataset_item.reference_id)
CLIENT.delete_dataset(dataset.id)


+ @pytest.mark.integration
+ def test_dataset_export_class_labels(dataset):
+     dataset.append(make_dataset_items())
+     # Create a box annotation from the test data
+     box_annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS[0])
+     dataset.annotate(annotations=[box_annotation])
+
+     # Wait for annotations to finish uploading (takes a while)
+     import time
+
+     time.sleep(40)
+     class_labels = dataset.export_class_labels()
+     # Compare against just the label from the test annotation
+     assert class_labels == [box_annotation.label]
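A side note on the fixed `time.sleep(40)`: it makes the test slow and can still flake if ingestion takes longer. One alternative, sketched here under the assumption that repeated exports are cheap (not part of this PR), is to poll the new endpoint until the expected labels appear or a deadline passes:

```python
import time

def wait_for_class_labels(dataset, expected, timeout=60.0, interval=2.0):
    """Poll dataset.export_class_labels() until all `expected` labels appear.

    Sketch only: `dataset` is a nucleus Dataset, `expected` is a list of
    label strings. Raises TimeoutError if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        labels = dataset.export_class_labels()
        if set(expected).issubset(labels):
            return labels
        time.sleep(interval)
    raise TimeoutError(f"labels {expected} not visible after {timeout}s")
```

The same pattern would apply to `test_slice_export_class_labels` in tests/test_slice.py below.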
49 changes: 38 additions & 11 deletions tests/test_slice.py
@@ -70,8 +70,7 @@ def test_slice_create_and_export(dataset):
slice_ref_ids = [item.reference_id for item in ds_items[:1]]
# This test assumes one box annotation per item.
annotations = [
-     BoxAnnotation.from_json(json_data)
-     for json_data in TEST_BOX_ANNOTATIONS
+     BoxAnnotation.from_json(json_data) for json_data in TEST_BOX_ANNOTATIONS
]
# Slice creation
slc = dataset.create_slice(
@@ -96,17 +95,45 @@ def get_expected_item(reference_id):
for row in slc.items_and_annotations():
reference_id = row[ITEM_KEY].reference_id
assert row[ITEM_KEY] == get_expected_item(reference_id)
- assert row[ANNOTATIONS_KEY][BOX_TYPE][
-     0
- ] == get_expected_box_annotation(reference_id)
+ assert row[ANNOTATIONS_KEY][BOX_TYPE][0] == get_expected_box_annotation(
+     reference_id
+ )

# test async
for row in slc.items_and_annotation_generator():
reference_id = row[ITEM_KEY].reference_id
assert row[ITEM_KEY] == get_expected_item(reference_id)
- assert row[ANNOTATIONS_KEY][BOX_TYPE][
-     0
- ] == get_expected_box_annotation(reference_id)
+ assert row[ANNOTATIONS_KEY][BOX_TYPE][0] == get_expected_box_annotation(
+     reference_id
+ )


+ @pytest.mark.integration
+ def test_slice_export_class_labels(dataset):
+     ds_items = dataset.items
+
+     slice_ref_ids = [item.reference_id for item in ds_items]
+     # Use a single box annotation from the test data.
+     annotations = [
+         BoxAnnotation.from_json(json_data)
+         for json_data in TEST_BOX_ANNOTATIONS[:1]
+     ]
+     # Slice creation
+     slc = dataset.create_slice(
+         name=TEST_SLICE_NAME,
+         reference_ids=slice_ref_ids,
+     )
+
+     dataset.annotate(annotations=annotations)
+
+     # Wait for annotations to finish uploading (takes a while)
+     import time
+
+     time.sleep(40)
+     class_labels = slc.export_class_labels()
+
+     expected_class_labels = [anno.label for anno in annotations]
+     assert class_labels == expected_class_labels
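As in the dataset test, the equality assertion assumes the endpoint returns labels in a stable order. If ordering is not guaranteed, an order-insensitive comparison (a sketch, not part of this PR) is safer:

```python
# Order-insensitive variant of the final assertion.
assert set(class_labels) == {anno.label for anno in annotations}
```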


# TODO(drake): investigate why this only flakes in circleci
@@ -140,9 +167,9 @@ def get_expected_item(reference_id):
for row in exported:
reference_id = row[ITEM_KEY].reference_id
assert row[ITEM_KEY] == get_expected_item(reference_id)
- assert row[PREDICTIONS_KEY][BOX_TYPE][
-     0
- ] == get_expected_box_prediction(reference_id)
+ assert row[PREDICTIONS_KEY][BOX_TYPE][0] == get_expected_box_prediction(
+     reference_id
+ )


def test_slice_append(dataset):