Skip to content

PixelRGBDDistribution: Additional tests on invalid RGBD values #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FullPixelDepthDistribution,
PixelDepthDistribution,
PixelRGBDDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.projection import PixelsPointsAssociation

Expand Down Expand Up @@ -79,9 +80,7 @@ def logpdf(
transformed_points, hyperparams
)
vertex_kernel = self.get_rgbd_vertex_kernel()
observed_rgbd_per_point = observed_rgbd.at[
points_to_pixels.x, points_to_pixels.y
].get(mode="drop", fill_value=-1.0)
observed_rgbd_per_point = points_to_pixels.get_point_rgbds(observed_rgbd)
latent_rgbd_per_point = jnp.concatenate(
(state["colors"], transformed_points[..., 2, None]), axis=-1
)
Expand All @@ -94,11 +93,18 @@ def logpdf(
state["visibility_prob"],
state["depth_nonreturn_prob"],
)
# the pixel kernel does not expect invalid observed_rgbd and will return
# -inf if it is invalid. We need to filter those out here.
# (invalid rgbd could happen when the vertex is projected out of the image)
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)

return scores.sum()

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# Note: The distributions were originally defined for per-pixel computation,
# but they should work for per-vertex computation as well
# but they should work for per-vertex computation as well, except that
# they don't expect observed_rgbd to be invalid, so we need to handle
# that manually.
return PixelRGBDDistribution(
FullPixelColorDistribution(),
FullPixelDepthDistribution(self.near, self.far),
Expand Down
51 changes: 25 additions & 26 deletions src/b3d/chisight/gen3d/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import namedtuple
from functools import partial
from functools import partial, wraps
from typing import NamedTuple

import jax
import jax.numpy as jnp
Expand All @@ -9,38 +9,36 @@
from genjax import UpdateProblemBuilder as U
from jax.random import split

from .inference_moves import (
from b3d.chisight.gen3d.inference_moves import (
propose_other_latents_given_pose,
propose_pose,
)
from .model import (
from b3d.chisight.gen3d.model import (
get_hypers,
get_prev_state,
)


# Use namedtuple rather than dict so we can hash this, and use it as a static arg to a jitted function.
InferenceHyperparams = namedtuple(
"InferenceHyperparams",
[
"n_poses",
"pose_proposal_std",
"pose_proposal_conc",
"effective_color_transition_scale",
],
)
"""
Parameters for the inference algorithm.
- n_poses: Number of poses to propose at each timestep.
- pose_proposal_std: Standard deviation of the position distribution for the pose.
- pose_proposal_conc: Concentration parameter for the orientation distribution for the pose.
- effective_color_transition_scale: This parameter is used in the color proposal.
When the color transition kernel is a laplace, this should be its scale.
When the color transition kernel is a different distribution, set this to something
that would make a laplace transition kernel propose with a somewhat similar spread
to the kernel you are using. (This parameter is used to decide
the size of the proposal in the color proposal, using a simple analysis
we conducted in the laplace case.)
"""
class InferenceHyperparams(NamedTuple):
    """
    Parameters for the inference algorithm.

    This is a NamedTuple rather than a dict so that instances are hashable
    and can be used as a static argument to a jitted function.

    - n_poses: Number of poses to propose at each timestep.
    - pose_proposal_std: Standard deviation of the position distribution for the pose.
    - pose_proposal_conc: Concentration parameter for the orientation distribution for the pose.
    - effective_color_transition_scale: This parameter is used in the color proposal.
        When the color transition kernel is a laplace, this should be its scale.
        When the color transition kernel is a different distribution, set this to something
        that would make a laplace transition kernel propose with a somewhat similar spread
        to the kernel you are using. (This parameter is used to decide
        the size of the proposal in the color proposal, using a simple analysis
        we conducted in the laplace case.)
    """

    n_poses: int
    pose_proposal_std: float
    pose_proposal_conc: float
    effective_color_transition_scale: float


@jax.jit
Expand Down Expand Up @@ -108,6 +106,7 @@ def inference_step(key, old_trace, observed_rgbd, inference_hyperparams):
)


@wraps(inference_step)
def inference_step_noweight(*args):
"""
Same as inference_step, but only returns the new trace
Expand Down
2 changes: 1 addition & 1 deletion src/b3d/chisight/gen3d/inference_moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def score_visprob_rgb(visprob, rgb):
.logpdf(
observed_rgbd=observed_rgbd_for_this_vertex,
latent_rgbd=jnp.append(rgb, latent_depth),
rgb_scale=new_state["color_scale"],
color_scale=new_state["color_scale"],
depth_scale=new_state["depth_scale"],
visibility_prob=visprob,
depth_nonreturn_prob=new_state["depth_nonreturn_prob"][vertex_index],
Expand Down
4 changes: 4 additions & 0 deletions src/b3d/chisight/gen3d/pixel_kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
FullPixelColorDistribution,
MixturePixelColorDistribution,
PixelColorDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import (
DEPTH_NONRETURN_VAL,
FullPixelDepthDistribution,
MixturePixelDepthDistribution,
PixelDepthDistribution,
Expand All @@ -12,6 +14,8 @@
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import PixelRGBDDistribution

__all__ = [
"is_unexplained",
"DEPTH_NONRETURN_VAL",
"FullPixelColorDistribution",
"FullPixelDepthDistribution",
"MixturePixelColorDistribution",
Expand Down
27 changes: 14 additions & 13 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
class PixelRGBDDistribution(genjax.ExactDensity):
"""
Distribution args:
- latent_rgbd: 4-array: RGBD value. (Should be [-1, -1, -1, -1] to indicate no point hits here.)
- rgb_scale: float
- latent_rgbd: 4-array: RGBD value. (a value of [-1, -1, -1, -1] indicates no point hits here.)
- color_scale: float
- depth_scale: float
- visibility_prob: float
- depth_nonreturn_prob: float

The support of the distribution is [0, 1]^3 x ([near, far] + {DEPTH_NONRETURN_VAL}).

If the logpdf of [-1, -1, -1, -1] is requested, this will return 0.0.
Note that this distribution expects the observed_rgbd to be valid. If an invalid
pixel is observed, the logpdf will return -inf.
"""

color_kernel: PixelColorDistribution
Expand All @@ -30,14 +31,14 @@ def sample(
self,
key: PRNGKey,
latent_rgbd: FloatArray,
rgb_scale,
depth_scale,
visibility_prob,
depth_nonreturn_prob,
color_scale: float,
depth_scale: float,
visibility_prob: float,
depth_nonreturn_prob: float,
) -> FloatArray:
keys = jax.random.split(key, 2)
observed_color = self.color_kernel.sample(
keys[0], latent_rgbd[:3], rgb_scale, visibility_prob
keys[0], latent_rgbd[:3], color_scale, visibility_prob
)
observed_depth = self.depth_kernel.sample(
keys[1], latent_rgbd[3], depth_scale, visibility_prob, depth_nonreturn_prob
Expand All @@ -48,13 +49,13 @@ def logpdf(
self,
observed_rgbd: FloatArray,
latent_rgbd: FloatArray,
rgb_scale,
depth_scale,
visibility_prob,
depth_nonreturn_prob,
color_scale: float,
depth_scale: float,
visibility_prob: float,
depth_nonreturn_prob: float,
) -> float:
color_logpdf = self.color_kernel.logpdf(
observed_rgbd[:3], latent_rgbd[:3], rgb_scale, visibility_prob
observed_rgbd[:3], latent_rgbd[:3], color_scale, visibility_prob
)
depth_logpdf = self.depth_kernel.logpdf(
observed_rgbd[3],
Expand Down
6 changes: 1 addition & 5 deletions src/b3d/chisight/gen3d/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,7 @@ def get_point_rgbds(self, rgbd_image: FloatArray) -> FloatArray:
by indexing into the given image.
Vertices that don't hit a pixel will have a value of (-1, -1, -1, -1).
"""
unfiltered = rgbd_image[self.x, self.y]
invalid_indices = jnp.logical_or(self.x == INVALID_IDX, self.y == INVALID_IDX)
return jnp.where(
invalid_indices[:, None], -jnp.ones_like(unfiltered), unfiltered
)
return rgbd_image.at[self.x, self.y].get(mode="drop", fill_value=-1.0)

def get_point_depths(self, rgbd_image: FloatArray) -> FloatArray:
"""
Expand Down
11 changes: 4 additions & 7 deletions src/b3d/chisight/gen3d/transition_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class LaplaceColorDriftKernel(DriftKernel):
This is a thin wrapper around the truncated_color_laplace distribution to
provide a consistent interface with other drift kernels.

Support: [0.0, 1.0]
Support: [0.0, 1.0]^3
"""

scale: float = Pytree.static()
Expand All @@ -244,16 +244,13 @@ def logpdf(self, new_value: ArrayLike, prev_value: ArrayLike) -> ArrayLike:
@Pytree.dataclass
class LaplaceNotTruncatedColorDriftKernel(DriftKernel):
"""A drift kernel that samples the 3 channels of the color from a specialized
truncated Laplace distribution, centered at the previous color. Values outside
of the bounds will be resampled from a small uniform window at the boundary.
This is a thin wrapper around the truncated_color_laplace distribution to
provide a consistent interface with other drift kernels.
Laplace distribution (not truncated), centered at the previous color. Values may
go outside of the valid color range ([0.0, 1.0]^3).

Support: [0.0, 1.0]
Support: (-inf, inf)^3
"""

scale: float = Pytree.static()
uniform_window_size: float = Pytree.static(default=_FIXED_COLOR_UNIFORM_WINDOW)

def sample(self, key: PRNGKey, prev_value: ArrayLike) -> ArrayLike:
return genjax.laplace.sample(key, prev_value, self.scale)
Expand Down
44 changes: 38 additions & 6 deletions tests/gen3d/test_pixel_rgbd_kernels.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import jax
import jax.numpy as jnp
import pytest
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
FullPixelColorDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import (
from b3d.chisight.gen3d.pixel_kernels import (
DEPTH_NONRETURN_VAL,
FullPixelColorDistribution,
FullPixelDepthDistribution,
PixelRGBDDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import PixelRGBDDistribution

near = 0.01
far = 20.0
Expand All @@ -20,7 +18,7 @@
FullPixelDepthDistribution(near, far),
),
(
0.01, # rgb_scale
0.01, # color_scale
0.01, # depth_scale
1 - 0.3, # visibility_prob
0.1, # depth_nonreturn_prob
Expand Down Expand Up @@ -92,3 +90,37 @@ def test_relative_logpdf(kernel_spec):
assert logpdf_5 > logpdf_6
# the score of the pixel should be higher when the rgbd is closer
assert logpdf_5 > logpdf_3


@pytest.mark.parametrize("kernel_spec", sample_kernels_to_test)
def test_invalid_pixel(kernel_spec):
    """Check the kernel's logpdf when the latent or the observation is invalid.

    With a latent RGBD of [-1, -1, -1, -1] (no point hits the pixel), the
    logpdf is expected to be the same regardless of the observed value and of
    the kernel parameters. An observed RGBD of [-1, -1, -1, -1] is expected
    to score -inf.
    """
    kernel, additional_args = kernel_spec

    # A latent value of [-1, -1, -1, -1] indicates no point hits here.
    latent_rgbd = -jnp.ones(4)
    logpdf_1 = kernel.logpdf(
        jnp.array([1.0, 0.5, 0.2, 4.0]), latent_rgbd, *additional_args
    )
    logpdf_2 = kernel.logpdf(
        jnp.array([0.0, 0.0, 0.0, 0.02]), latent_rgbd, *additional_args
    )
    # The observation does not affect the logpdf.
    assert logpdf_1 == logpdf_2

    logpdf_3 = kernel.logpdf(
        jnp.array([1.0, 0.5, 0.2, 4.0]), latent_rgbd, 0.1, 0.4, 0.2, 0.1
    )
    logpdf_4 = kernel.logpdf(
        jnp.array([0.0, 0.0, 0.0, 0.02]), latent_rgbd, 0.3, 0.5, 0.4, 0.2
    )
    # ...and the values of the parameters don't matter either.
    assert logpdf_2 == logpdf_3
    assert logpdf_3 == logpdf_4

    # IMPORTANT: by design, every pixel should have a valid color, so an
    # observation of [-1, -1, -1, -1] is not within the support of the pixel
    # distribution.
    logpdf_5 = kernel.logpdf(
        jnp.array([-1.0, -1.0, -1.0, -1.0]), latent_rgbd, *additional_args
    )
    assert logpdf_5 == -jnp.inf
Loading