Skip to content

Commit

Permalink
Add 'BoundingBox', 'Pose2D' and 'Pose3D' models (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour authored Nov 7, 2024
1 parent 10c4e2a commit e455180
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datachain.lib import func
from datachain.lib import func, models
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
Expand Down Expand Up @@ -38,5 +38,6 @@
"func",
"is_chain_type",
"metrics",
"models",
"param",
]
3 changes: 0 additions & 3 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from pyarrow.dataset import dataset
from pydantic import Field, field_validator

if TYPE_CHECKING:
from typing_extensions import Self

from datachain.client.fileslice import FileSlice
from datachain.lib.data_model import DataModel
from datachain.lib.utils import DataChainError
Expand Down
5 changes: 5 additions & 0 deletions src/datachain/lib/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import yolo
from .bbox import BBox
from .pose import Pose, Pose3D

__all__ = ["BBox", "Pose", "Pose3D", "yolo"]
45 changes: 45 additions & 0 deletions src/datachain/lib/models/bbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Optional

from pydantic import Field

from datachain.lib.data_model import DataModel


class BBox(DataModel):
    """
    A data model for representing bounding boxes.

    Attributes:
        title (str): The title of the bounding box.
        x1 (float): The x-coordinate of the top-left corner of the bounding box.
        y1 (float): The y-coordinate of the top-left corner of the bounding box.
        x2 (float): The x-coordinate of the bottom-right corner of the bounding box.
        y2 (float): The y-coordinate of the bottom-right corner of the bounding box.

    The bounding box is defined by two points:
        - (x1, y1): The top-left corner of the box.
        - (x2, y2): The bottom-right corner of the box.
    """

    title: str = Field(default="")
    x1: float = Field(default=0)
    y1: float = Field(default=0)
    x2: float = Field(default=0)
    y2: float = Field(default=0)

    @staticmethod
    def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox":
        """
        Converts a bounding box in (x, y, width, height) format
        to a BBox data model instance.

        Args:
            bbox (list[float]): A bounding box, represented as a list
                of four floats [x, y, width, height].
            title (Optional[str]): The title of the bounding box. ``None``
                (or any other falsy value) is stored as an empty string.

        Returns:
            BBox: An instance of the BBox data model whose second corner is
                computed as (x + width, y + height).

        Raises:
            AssertionError: If ``bbox`` does not contain exactly 4 elements.
        """
        # NOTE: asserts are skipped under ``python -O``; in that mode a list of
        # the wrong length fails on the unpacking below with ValueError instead.
        assert len(bbox) == 4, f"Bounding box must have 4 elements, got {len(bbox)}"
        x, y, w, h = bbox
        return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h)
37 changes: 37 additions & 0 deletions src/datachain/lib/models/pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pydantic import Field

from datachain.lib.data_model import DataModel


class Pose(DataModel):
    """
    A data model for representing pose keypoints.

    Attributes:
        x (list[float]): The x-coordinates of the keypoints.
        y (list[float]): The y-coordinates of the keypoints.

    The keypoints are represented as lists of x and y coordinates, where each index
    corresponds to a specific body part.

    Both lists default to an empty list when not provided.
    """

    # A ``None`` default contradicts the ``list[float]`` annotation (pydantic
    # does not validate defaults), so an unset field would not be a list at
    # all. ``default_factory=list`` gives every instance its own empty list.
    x: list[float] = Field(default_factory=list)
    y: list[float] = Field(default_factory=list)


class Pose3D(DataModel):
    """
    A data model for representing 3D pose keypoints.

    Attributes:
        x (list[float]): The x-coordinates of the keypoints.
        y (list[float]): The y-coordinates of the keypoints.
        visible (list[float]): The visibility of the keypoints.

    The keypoints are represented as lists of x, y, and visibility values,
    where each index corresponds to a specific body part.

    All three lists default to an empty list when not provided.
    """

    # A ``None`` default contradicts the ``list[float]`` annotation (pydantic
    # does not validate defaults), so an unset field would not be a list at
    # all. ``default_factory=list`` gives every instance its own empty list.
    x: list[float] = Field(default_factory=list)
    y: list[float] = Field(default_factory=list)
    visible: list[float] = Field(default_factory=list)
39 changes: 39 additions & 0 deletions src/datachain/lib/models/yolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
This module contains the YOLO models.
YOLO stands for "You Only Look Once", a family of object detection models that
are designed to be fast and accurate. The models are trained to detect objects
in images by dividing the image into a grid and predicting the bounding boxes
and class probabilities for each grid cell.
More information about YOLO can be found here:
- https://pjreddie.com/darknet/yolo/
- https://docs.ultralytics.com/
"""


class PoseBodyPart:
    """
    Named integer indices for the body parts of YOLO pose keypoints.

    Each attribute is a plain ``int`` (0 through 16, in the order listed
    below) intended to be used as a position in a keypoint list.

    More information about the body parts can be found here:
    https://docs.ultralytics.com/tasks/pose/
    """

    # The order of the names defines their values: the first name is 0,
    # the last one is 16.
    (
        nose,
        left_eye,
        right_eye,
        left_ear,
        right_ear,
        left_shoulder,
        right_shoulder,
        left_elbow,
        right_elbow,
        left_wrist,
        right_wrist,
        left_hip,
        right_hip,
        left_knee,
        right_knee,
        left_ankle,
        right_ankle,
    ) = range(17)
50 changes: 50 additions & 0 deletions tests/unit/lib/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from datachain.lib import models


def test_bbox():
    """A BBox built from explicit fields dumps back to the same values."""
    expected = {
        "title": "BBox",
        "x1": 0.5,
        "y1": 1.5,
        "x2": 2.5,
        "y2": 3.5,
    }
    box = models.BBox(**expected)
    assert box.model_dump() == expected


def test_bbox_from_xywh():
    """from_xywh turns [x, y, w, h] into corner coordinates, with and without a title."""
    xywh = [0.5, 1.5, 2.5, 3.5]
    corners = {"x1": 0.5, "y1": 1.5, "x2": 3, "y2": 5}

    # Without a title the model falls back to an empty string.
    untitled = models.BBox.from_xywh(xywh)
    assert untitled.model_dump() == {"title": "", **corners}

    # An explicit title is carried through unchanged.
    titled = models.BBox.from_xywh(xywh, title="BBox")
    assert titled.model_dump() == {"title": "BBox", **corners}


def test_pose():
    """A Pose round-trips its coordinates and is indexable by PoseBodyPart."""
    xs = [i * 0.5 for i in range(17)]
    ys = [i * 1.5 for i in range(17)]
    pose = models.Pose(x=xs, y=ys)
    assert pose.model_dump() == {"x": xs, "y": ys}

    parts = models.yolo.PoseBodyPart
    # (body part index, expected x-coordinate at that index)
    expectations = [
        (parts.nose, 0),
        (parts.left_eye, 0.5),
        (parts.right_eye, 1),
        (parts.left_ear, 1.5),
        (parts.right_ear, 2),
        (parts.left_shoulder, 2.5),
        (parts.right_shoulder, 3),
        (parts.left_elbow, 3.5),
        (parts.right_elbow, 4),
        (parts.left_wrist, 4.5),
        (parts.right_wrist, 5),
        (parts.left_hip, 5.5),
        (parts.right_hip, 6),
        (parts.left_knee, 6.5),
        (parts.right_knee, 7),
        (parts.left_ankle, 7.5),
        (parts.right_ankle, 8),
    ]
    for part, expected_x in expectations:
        assert pose.x[part] == expected_x

0 comments on commit e455180

Please sign in to comment.