Skip to content

Commit 138d765

Browse files
save tests passing
1 parent 4f93c8f commit 138d765

File tree

2 files changed

+69
-13
lines changed

2 files changed

+69
-13
lines changed

src/b3d/chisight/gen3d/model.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import genjax
2+
import jax
23
import jax.numpy as jnp
34
import rerun as rr
5+
from genjax import ChoiceMapBuilder as C
46

57
import b3d
68

@@ -71,7 +73,28 @@ def dynamic_object_generative_model(hyperparams, previous_state):
7173
}
7274

7375

74-
### Viz ###
76+
### Helpers ###
77+
78+
79+
def make_colors_choicemap(colors):
80+
return jax.vmap(lambda idx: C["colors", idx].set(colors[idx]))(
81+
jnp.arange(len(colors))
82+
)
83+
84+
85+
def make_visibility_prob_choicemap(visibility_prob):
86+
return jax.vmap(lambda idx: C["visibility_prob", idx].set(visibility_prob[idx]))(
87+
jnp.arange(len(visibility_prob))
88+
)
89+
90+
91+
def make_depth_nonreturn_prob_choicemap(depth_nonreturn_prob):
92+
return jax.vmap(
93+
lambda idx: C["depth_nonreturn_prob", idx].set(depth_nonreturn_prob[idx])
94+
)(jnp.arange(len(depth_nonreturn_prob)))
95+
96+
97+
### Visualization Code ###
7598
def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
7699
b3d.rr_set_time(t)
77100
hyperparams, _ = trace.get_args()

tests/gen3d/test_model.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
### IMPORTS ###
2-
32
import b3d
43
import b3d.chisight.gen3d.model
54
import b3d.chisight.gen3d.transition_kernels as transition_kernels
65
import jax
76
import jax.numpy as jnp
87
from b3d import Pose
8+
from b3d.chisight.gen3d.model import (
9+
make_colors_choicemap,
10+
make_depth_nonreturn_prob_choicemap,
11+
make_visibility_prob_choicemap,
12+
)
913
from genjax import ChoiceMapBuilder as C
1014

1115
b3d.rr_init("test_dynamic_object_model")
1216

1317

1418
def test_model_no_likelihood():
19+
importance = jax.jit(
20+
b3d.chisight.gen3d.model.dynamic_object_generative_model.importance
21+
)
1522
num_vertices = 100
1623
vertices = jax.random.uniform(
1724
jax.random.PRNGKey(0), (num_vertices, 3), minval=-1, maxval=1
@@ -51,12 +58,10 @@ def test_model_no_likelihood():
5158
}
5259

5360
key = jax.random.PRNGKey(0)
54-
importance = jax.jit(
55-
b3d.chisight.gen3d.model.dynamic_object_generative_model.importance
56-
)
61+
trace = importance(key, C.n(), (hyperparams, previous_state))[0]
5762

58-
trace, _ = importance(key, C.n(), (hyperparams, previous_state))
59-
assert trace.get_score().shape == ()
63+
key = jax.random.PRNGKey(0)
64+
hyperparams, previous_state = trace.get_args()
6065

6166
traces = [trace]
6267
for t in range(100):
@@ -65,6 +70,7 @@ def test_model_no_likelihood():
6570
trace, _ = importance(key, C.n(), (hyperparams, previous_state))
6671
b3d.chisight.gen3d.model.viz_trace(trace, t)
6772
traces.append(trace)
73+
6874
colors_over_time = jnp.array(
6975
[trace.get_choices()["colors", ...] for trace in traces]
7076
)
@@ -74,12 +80,15 @@ def test_model_no_likelihood():
7480
fig, ax = plt.subplots(4, 1, sharex=True, figsize=(10, 15))
7581
point_index = 0
7682

77-
fig.suptitle(f"""pose_kernel max_shift: {hyperparams['pose_kernel'].max_shift},
78-
color_kernel scale: {hyperparams['color_kernel'].scale},
79-
visibility_prob_kernel resample_probability: {hyperparams['visibility_prob_kernel'].resample_probability},
80-
depth_nonreturn_prob_kernel resample_probability: {hyperparams['depth_nonreturn_prob_kernel'].resample_probability},
81-
depth_scale_kernel resample_probability: {hyperparams['depth_scale_kernel'].resample_probability},
82-
color_scale_kernel resample_probability: {hyperparams['color_scale_kernel'].resample_probability}""")
83+
fig.suptitle(
84+
f"""
85+
pose_kernel max_shift: {hyperparams['pose_kernel'].max_shift},
86+
color_kernel scale: {hyperparams['color_kernel'].scale},
87+
visibility_prob_kernel resample_probability: {hyperparams['visibility_prob_kernel'].resample_probability},
88+
depth_nonreturn_prob_kernel resample_probability: {hyperparams['depth_nonreturn_prob_kernel'].resample_probability},
89+
depth_scale_kernel resample_probability: {hyperparams['depth_scale_kernel'].resample_probability},
90+
color_scale_kernel resample_probability: {hyperparams['color_scale_kernel'].resample_probability}"""
91+
)
8392
ax[0].set_title(f"Color of vertex {point_index}")
8493
ax[0].plot(colors_over_time[..., point_index, 0], color="r")
8594
ax[0].plot(colors_over_time[..., point_index, 1], color="g")
@@ -109,6 +118,30 @@ def test_model_no_likelihood():
109118
fig.supxlabel("Time")
110119
fig.savefig("test_gen3d_model.png")
111120

121+
colors = trace.get_choices()["colors", ...]
122+
new_colors = colors + 0.01
123+
new_colors_choicemap = make_colors_choicemap(new_colors)
124+
new_trace = trace.update(key, new_colors_choicemap)[0]
125+
assert jnp.allclose(new_trace.get_choices()["colors", ...], new_colors)
126+
127+
visibility_prob = trace.get_choices()["visibility_prob", ...]
128+
new_visibility_prob = visibility_prob + 0.01
129+
new_visibility_prob_choicemap = make_visibility_prob_choicemap(new_visibility_prob)
130+
new_trace = trace.update(key, new_visibility_prob_choicemap)[0]
131+
assert jnp.allclose(
132+
new_trace.get_choices()["visibility_prob", ...], new_visibility_prob
133+
)
134+
135+
depth_nonreturn_prob = trace.get_choices()["depth_nonreturn_prob", ...]
136+
new_depth_nonreturn_prob = depth_nonreturn_prob + 0.01
137+
new_depth_nonreturn_prob_choicemap = make_depth_nonreturn_prob_choicemap(
138+
new_depth_nonreturn_prob
139+
)
140+
new_trace = trace.update(key, new_depth_nonreturn_prob_choicemap)[0]
141+
assert jnp.allclose(
142+
new_trace.get_choices()["depth_nonreturn_prob", ...], new_depth_nonreturn_prob
143+
)
144+
112145

113146
if __name__ == "__main__":
114147
test_model_no_likelihood()

0 commit comments

Comments
 (0)