diff --git a/raillabel/format/object.py b/raillabel/format/object.py index a0b9989..6a3eaa3 100644 --- a/raillabel/format/object.py +++ b/raillabel/format/object.py @@ -4,8 +4,16 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING +from uuid import UUID -from raillabel.json_format import JSONObject +from raillabel.json_format import JSONElementDataPointer, JSONFrameInterval, JSONObject + +from ._attributes import _attributes_to_json +from .frame_interval import FrameInterval + +if TYPE_CHECKING: + from .frame import Frame @dataclass @@ -26,3 +34,71 @@ def from_json(cls, json: JSONObject) -> Object: name=json.name, type=json.type, ) + + def to_json(self, object_id: UUID, frames: dict[int, Frame]) -> JSONObject: + """Export this object into the RailLabel JSON format.""" + return JSONObject( + name=self.name, + type=self.type, + frame_intervals=_frame_intervals_to_json(object_id, frames), + object_data_pointers=_object_data_pointers_to_json(object_id, self.type, frames), + ) + + +def _frame_intervals_to_json(object_id: UUID, frames: dict[int, Frame]) -> list[JSONFrameInterval]: + frames_with_this_object = set() + + for frame_id, frame in frames.items(): + for annotation in frame.annotations.values(): + if annotation.object_id == object_id: + frames_with_this_object.add(frame_id) + continue + + return [fi.to_json() for fi in FrameInterval.from_frame_ids(list(frames_with_this_object))] + + +def _object_data_pointers_to_json( + object_id: UUID, object_type: str, frames: dict[int, Frame] +) -> dict[str, JSONElementDataPointer]: + pointers_raw = {} + + for frame_id, frame in frames.items(): + for annotation in [ann for ann in frame.annotations.values() if ann.object_id == object_id]: + annotation_name = annotation.name(object_type) + if annotation_name not in pointers_raw: + pointers_raw[annotation_name] = { + "frame_intervals": set(), + "type": annotation_name.split("__")[1], + "attribute_pointers": {}, + } + + pointers_raw[annotation_name]["frame_intervals"].add(frame_id) # type: ignore + json_attributes = _attributes_to_json(annotation.attributes) + + if json_attributes is None: + continue + + for attribute in json_attributes.boolean: # type: ignore + pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "boolean" # type: ignore + + for attribute in json_attributes.num: # type: ignore + pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "num" # type: ignore + + for attribute in json_attributes.text: # type: ignore + pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "text" # type: ignore + + for attribute in json_attributes.vec: # type: ignore + pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "vec" # type: ignore + + object_data_pointers = {} + for annotation_name, object_data_pointer in pointers_raw.items(): + object_data_pointers[annotation_name] = JSONElementDataPointer( + type=object_data_pointer["type"], + frame_intervals=[ + fi.to_json() + for fi in FrameInterval.from_frame_ids(list(object_data_pointer["frame_intervals"])) # type: ignore + ], + attribute_pointers=object_data_pointer["attribute_pointers"], + ) + + return object_data_pointers diff --git a/tests/test_raillabel/format/test_object.py b/tests/test_raillabel/format/test_object.py index 9931826..98410b1 100644 --- a/tests/test_raillabel/format/test_object.py +++ b/tests/test_raillabel/format/test_object.py @@ -7,7 +7,7 @@ import pytest -from raillabel.json_format import JSONObject +from raillabel.json_format import JSONObject, JSONFrameInterval, JSONElementDataPointer from raillabel.format import Object # == Fixtures ========================= @@ -26,6 +26,42 @@ def object_person_json() -> JSONObject: return JSONObject( name="person_0032", type="person", + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + object_data_pointers={ + "rgb_middle__bbox__person": JSONElementDataPointer( + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + type="bbox", + attribute_pointers={ + "has_red_hat": "boolean", + "has_green_hat": "boolean", + "number_of_red_clothing_items": "num", + "color_of_hat": "text", + "clothing_items": "vec", + }, + ), + "lidar__cuboid__person": JSONElementDataPointer( + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + type="cuboid", + attribute_pointers={ + "has_red_hat": "boolean", + "has_green_hat": "boolean", + "number_of_red_clothing_items": "num", + "color_of_hat": "text", + "clothing_items": "vec", + }, + ), + "lidar__vec__person": JSONElementDataPointer( + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + type="vec", + attribute_pointers={ + "has_red_hat": "boolean", + "has_green_hat": "boolean", + "number_of_red_clothing_items": "num", + "color_of_hat": "text", + "clothing_items": "vec", + }, + ), + }, ) @@ -76,5 +112,10 @@ def test_from_json__track(object_track, object_track_json): assert actual == object_track +def test_to_json__person(object_person, object_person_json, object_person_id, frame): + actual = object_person.to_json(object_person_id, {1: frame}) + assert actual == object_person_json + + if __name__ == "__main__": - pytest.main([__file__, "-v"]) + pytest.main([__file__, "-vv"])