1
1
### IMPORTS ###
2
-
3
2
import b3d
4
3
import b3d .chisight .gen3d .model
5
4
import b3d .chisight .gen3d .transition_kernels as transition_kernels
6
5
import jax
7
6
import jax .numpy as jnp
8
7
from b3d import Pose
8
+ from b3d .chisight .gen3d .model import (
9
+ make_colors_choicemap ,
10
+ make_depth_nonreturn_prob_choicemap ,
11
+ make_visibility_prob_choicemap ,
12
+ )
9
13
from genjax import ChoiceMapBuilder as C
10
14
11
15
b3d .rr_init ("test_dynamic_object_model" )
12
16
13
17
14
18
def test_model_no_likelihood ():
19
+ importance = jax .jit (
20
+ b3d .chisight .gen3d .model .dynamic_object_generative_model .importance
21
+ )
15
22
num_vertices = 100
16
23
vertices = jax .random .uniform (
17
24
jax .random .PRNGKey (0 ), (num_vertices , 3 ), minval = - 1 , maxval = 1
@@ -51,12 +58,10 @@ def test_model_no_likelihood():
51
58
}
52
59
53
60
key = jax .random .PRNGKey (0 )
54
- importance = jax .jit (
55
- b3d .chisight .gen3d .model .dynamic_object_generative_model .importance
56
- )
61
+ trace = importance (key , C .n (), (hyperparams , previous_state ))[0 ]
57
62
58
- trace , _ = importance ( key , C . n (), ( hyperparams , previous_state ) )
59
- assert trace . get_score (). shape == ()
63
+ key = jax . random . PRNGKey ( 0 )
64
+ hyperparams , previous_state = trace . get_args ()
60
65
61
66
traces = [trace ]
62
67
for t in range (100 ):
@@ -65,6 +70,7 @@ def test_model_no_likelihood():
65
70
trace , _ = importance (key , C .n (), (hyperparams , previous_state ))
66
71
b3d .chisight .gen3d .model .viz_trace (trace , t )
67
72
traces .append (trace )
73
+
68
74
colors_over_time = jnp .array (
69
75
[trace .get_choices ()["colors" , ...] for trace in traces ]
70
76
)
@@ -74,12 +80,15 @@ def test_model_no_likelihood():
74
80
fig , ax = plt .subplots (4 , 1 , sharex = True , figsize = (10 , 15 ))
75
81
point_index = 0
76
82
77
- fig .suptitle (f"""pose_kernel max_shift: { hyperparams ['pose_kernel' ].max_shift } ,
78
- color_kernel scale: { hyperparams ['color_kernel' ].scale } ,
79
- visibility_prob_kernel resample_probability: { hyperparams ['visibility_prob_kernel' ].resample_probability } ,
80
- depth_nonreturn_prob_kernel resample_probability: { hyperparams ['depth_nonreturn_prob_kernel' ].resample_probability } ,
81
- depth_scale_kernel resample_probability: { hyperparams ['depth_scale_kernel' ].resample_probability } ,
82
- color_scale_kernel resample_probability: { hyperparams ['color_scale_kernel' ].resample_probability } """ )
83
+ fig .suptitle (
84
+ f"""
85
+ pose_kernel max_shift: { hyperparams ['pose_kernel' ].max_shift } ,
86
+ color_kernel scale: { hyperparams ['color_kernel' ].scale } ,
87
+ visibility_prob_kernel resample_probability: { hyperparams ['visibility_prob_kernel' ].resample_probability } ,
88
+ depth_nonreturn_prob_kernel resample_probability: { hyperparams ['depth_nonreturn_prob_kernel' ].resample_probability } ,
89
+ depth_scale_kernel resample_probability: { hyperparams ['depth_scale_kernel' ].resample_probability } ,
90
+ color_scale_kernel resample_probability: { hyperparams ['color_scale_kernel' ].resample_probability } """
91
+ )
83
92
ax [0 ].set_title (f"Color of vertex { point_index } " )
84
93
ax [0 ].plot (colors_over_time [..., point_index , 0 ], color = "r" )
85
94
ax [0 ].plot (colors_over_time [..., point_index , 1 ], color = "g" )
@@ -109,6 +118,30 @@ def test_model_no_likelihood():
109
118
fig .supxlabel ("Time" )
110
119
fig .savefig ("test_gen3d_model.png" )
111
120
121
+ colors = trace .get_choices ()["colors" , ...]
122
+ new_colors = colors + 0.01
123
+ new_colors_choicemap = make_colors_choicemap (new_colors )
124
+ new_trace = trace .update (key , new_colors_choicemap )[0 ]
125
+ assert jnp .allclose (new_trace .get_choices ()["colors" , ...], new_colors )
126
+
127
+ visibility_prob = trace .get_choices ()["visibility_prob" , ...]
128
+ new_visibility_prob = visibility_prob + 0.01
129
+ new_visibility_prob_choicemap = make_visibility_prob_choicemap (new_visibility_prob )
130
+ new_trace = trace .update (key , new_visibility_prob_choicemap )[0 ]
131
+ assert jnp .allclose (
132
+ new_trace .get_choices ()["visibility_prob" , ...], new_visibility_prob
133
+ )
134
+
135
+ depth_nonreturn_prob = trace .get_choices ()["depth_nonreturn_prob" , ...]
136
+ new_depth_nonreturn_prob = depth_nonreturn_prob + 0.01
137
+ new_depth_nonreturn_prob_choicemap = make_depth_nonreturn_prob_choicemap (
138
+ new_depth_nonreturn_prob
139
+ )
140
+ new_trace = trace .update (key , new_depth_nonreturn_prob_choicemap )[0 ]
141
+ assert jnp .allclose (
142
+ new_trace .get_choices ()["depth_nonreturn_prob" , ...], new_depth_nonreturn_prob
143
+ )
144
+
112
145
113
146
if __name__ == "__main__" :
114
147
test_model_no_likelihood ()
0 commit comments