Skip to content

Commit

Permalink
SQL Index Dataset
Browse files Browse the repository at this point in the history
Summary:
Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay.

It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS.

Reviewed By: bottler

Differential Revision: D45086611

fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
  • Loading branch information
shapovalov authored and facebook-github-bot committed Apr 25, 2023
1 parent 7aeedd1 commit 32e1992
Show file tree
Hide file tree
Showing 10 changed files with 2,309 additions and 6 deletions.
36 changes: 30 additions & 6 deletions pytorch3d/implicitron/dataset/frame_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
load_blobs: bool = True,
) -> FrameDataSubtype:
"""An abstract method to build the frame data based on raw frame/sequence
annotations, load the binary data and adjust them according to the metadata.
Expand All @@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
Beware that modifications of frame data are done in-place.
Args:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
dataset_root: The root folder of the dataset; all paths in frame / sequence
annotations are defined w.r.t. this root. Has to be set if any of the
load_* flabs below is true.
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
Expand Down Expand Up @@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
path_manager: Optionally a PathManager for interpreting paths in a special way.
"""

dataset_root: str = ""
dataset_root: Optional[str] = None
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
Expand All @@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
box_crop_context: float = 0.3
path_manager: Any = None

def __post_init__(self) -> None:
load_any_blob = (
self.load_images
or self.load_depths
or self.load_depth_masks
or self.load_masks
or self.load_point_clouds
)
if load_any_blob and self.dataset_root is None:
raise ValueError(
"dataset_root must be set to load any blob data. "
"Make sure it is set in either FrameDataBuilder or Dataset params."
)

if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore
raise ValueError(
f"dataset_root is passed but {self.dataset_root} does not exist."
)

def build(
self,
frame_annotation: types.FrameAnnotation,
Expand Down Expand Up @@ -567,7 +588,7 @@ def build(
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
Expand Down Expand Up @@ -612,7 +633,8 @@ def build(
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[np.ndarray, str]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
Expand Down Expand Up @@ -647,7 +669,7 @@ def _load_mask_depth(
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
assert self.dataset_root is not None and entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

Expand All @@ -657,6 +679,7 @@ def _load_mask_depth(

if self.load_depth_masks:
assert entry_depth.mask_path is not None
# pyre-ignore
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = load_depth_mask(self._local_path(mask_path))
else:
Expand Down Expand Up @@ -705,6 +728,7 @@ def _fix_point_cloud_path(self, path: str) -> str:
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)

def _local_path(self, path: str) -> str:
Expand Down
161 changes: 161 additions & 0 deletions pytorch3d/implicitron/dataset/orm_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This functionality requires SQLAlchemy 2.0 or later.

import math
import struct
from typing import Optional, Tuple

import numpy as np

from pytorch3d.implicitron.dataset.types import (
DepthAnnotation,
ImageAnnotation,
MaskAnnotation,
PointCloudAnnotation,
VideoAnnotation,
ViewpointAnnotation,
)

from sqlalchemy import LargeBinary
from sqlalchemy.orm import (
composite,
DeclarativeBase,
Mapped,
mapped_column,
MappedAsDataclass,
)
from sqlalchemy.types import TypeDecorator


# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape):
class NumpyArrayType(TypeDecorator):
impl = LargeBinary

def process_bind_param(self, value, dialect):
if value is not None:
if value.shape != shape:
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
return value.astype(np.float32).tobytes()
return None

def process_result_value(self, value, dialect):
if value is not None:
return np.frombuffer(value, dtype=np.float32).reshape(shape)
return None

return NumpyArrayType


def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
format_symbol = {
float: "f", # float32
int: "i", # int32
}[dtype]

class TupleType(TypeDecorator):
impl = LargeBinary
_format = format_symbol * math.prod(shape)

def process_bind_param(self, value, _):
if value is None:
return None

if len(shape) > 1:
value = np.array(value, dtype=dtype).reshape(-1)

return struct.pack(TupleType._format, *value)

def process_result_value(self, value, _):
if value is None:
return None

loaded = struct.unpack(TupleType._format, value)
if len(shape) > 1:
loaded = _rec_totuple(
np.array(loaded, dtype=dtype).reshape(shape).tolist()
)

return loaded

return TupleType


def _rec_totuple(t):
if isinstance(t, list):
return tuple(_rec_totuple(x) for x in t)

return t


class Base(MappedAsDataclass, DeclarativeBase):
"""subclasses will be converted to dataclasses"""


class SqlFrameAnnotation(Base):
__tablename__ = "frame_annots"

sequence_name: Mapped[str] = mapped_column(primary_key=True)
frame_number: Mapped[int] = mapped_column(primary_key=True)
frame_timestamp: Mapped[float] = mapped_column(index=True)

image: Mapped[ImageAnnotation] = composite(
mapped_column("_image_path"),
mapped_column("_image_size", TupleTypeFactory(int)),
)

depth: Mapped[DepthAnnotation] = composite(
mapped_column("_depth_path", nullable=True),
mapped_column("_depth_scale_adjustment", nullable=True),
mapped_column("_depth_mask_path", nullable=True),
)

mask: Mapped[MaskAnnotation] = composite(
mapped_column("_mask_path", nullable=True),
mapped_column("_mask_mass", index=True, nullable=True),
mapped_column(
"_mask_bounding_box_xywh",
TupleTypeFactory(float, shape=(4,)),
nullable=True,
),
)

viewpoint: Mapped[ViewpointAnnotation] = composite(
mapped_column(
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
),
mapped_column(
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
),
mapped_column(
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
),
mapped_column(
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
),
mapped_column("_viewpoint_intrinsics_format", nullable=True),
)


class SqlSequenceAnnotation(Base):
__tablename__ = "sequence_annots"

sequence_name: Mapped[str] = mapped_column(primary_key=True)
category: Mapped[str] = mapped_column(index=True)

video: Mapped[VideoAnnotation] = composite(
mapped_column("_video_path", nullable=True),
mapped_column("_video_length", nullable=True),
)
point_cloud: Mapped[PointCloudAnnotation] = composite(
mapped_column("_point_cloud_path", nullable=True),
mapped_column("_point_cloud_quality_score", nullable=True),
mapped_column("_point_cloud_n_points", nullable=True),
)
# the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
Loading

0 comments on commit 32e1992

Please sign in to comment.