|  | 
| 1 |  | -module TestExtUtils | 
| 2 |  | - | 
| 3 |  | -################################################### | 
| 4 |  | -# These used to be in DPPL/src/test_utils.jl ###### | 
| 5 |  | -################################################### | 
|  | 1 | +module TestUtils | 
| 6 | 2 | 
 | 
| 7 | 3 | using AbstractMCMC | 
| 8 | 4 | using DynamicPPL | 
| @@ -1101,123 +1097,4 @@ function DynamicPPL.dot_tilde_observe( | 
| 1101 | 1097 |     return logp * context.mod, vi | 
| 1102 | 1098 | end | 
| 1103 | 1099 | 
 | 
| 1104 |  | - | 
| 1105 |  | - | 
| 1106 |  | -################################################### | 
| 1107 |  | -# These used to be in DPPL/test/test_util.jl ###### | 
| 1108 |  | -################################################### | 
| 1109 |  | - | 
| 1110 |  | -# default model | 
| 1111 |  | -@model function gdemo_d() | 
| 1112 |  | -    s ~ InverseGamma(2, 3) | 
| 1113 |  | -    m ~ Normal(0, sqrt(s)) | 
| 1114 |  | -    1.5 ~ Normal(m, sqrt(s)) | 
| 1115 |  | -    2.0 ~ Normal(m, sqrt(s)) | 
| 1116 |  | -    return s, m | 
| 1117 |  | -end | 
| 1118 |  | -const gdemo_default = gdemo_d() | 
| 1119 |  | - | 
| 1120 |  | -function test_model_ad(model, logp_manual) | 
| 1121 |  | -    vi = VarInfo(model) | 
| 1122 |  | -    x = DynamicPPL.getall(vi) | 
| 1123 |  | - | 
| 1124 |  | -    # Log probabilities using the model. | 
| 1125 |  | -    ℓ = DynamicPPL.LogDensityFunction(model, vi) | 
| 1126 |  | -    logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) | 
| 1127 |  | - | 
| 1128 |  | -    # Check that both functions return the same values. | 
| 1129 |  | -    lp = logp_manual(x) | 
| 1130 |  | -    @test logp_model(x) ≈ lp | 
| 1131 |  | - | 
| 1132 |  | -    # Gradients based on the manual implementation. | 
| 1133 |  | -    grad = ForwardDiff.gradient(logp_manual, x) | 
| 1134 |  | - | 
| 1135 |  | -    y, back = Tracker.forward(logp_manual, x) | 
| 1136 |  | -    @test Tracker.data(y) ≈ lp | 
| 1137 |  | -    @test Tracker.data(back(1)[1]) ≈ grad | 
| 1138 |  | - | 
| 1139 |  | -    y, back = Zygote.pullback(logp_manual, x) | 
| 1140 |  | -    @test y ≈ lp | 
| 1141 |  | -    @test back(1)[1] ≈ grad | 
| 1142 |  | - | 
| 1143 |  | -    # Gradients based on the model. | 
| 1144 |  | -    @test ForwardDiff.gradient(logp_model, x) ≈ grad | 
| 1145 |  | - | 
| 1146 |  | -    y, back = Tracker.forward(logp_model, x) | 
| 1147 |  | -    @test Tracker.data(y) ≈ lp | 
| 1148 |  | -    @test Tracker.data(back(1)[1]) ≈ grad | 
| 1149 |  | - | 
| 1150 |  | -    y, back = Zygote.pullback(logp_model, x) | 
| 1151 |  | -    @test y ≈ lp | 
| 1152 |  | -    @test back(1)[1] ≈ grad | 
| 1153 |  | -end | 
| 1154 |  | - | 
| 1155 |  | -""" | 
| 1156 |  | -    test_setval!(model, chain; sample_idx = 1, chain_idx = 1) | 
| 1157 |  | -
 | 
| 1158 |  | -Test `setval!` on `model` and `chain`. | 
| 1159 |  | -
 | 
| 1160 |  | -Worth noting that this only supports models containing symbols of the forms | 
| 1161 |  | -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. | 
| 1162 |  | -""" | 
| 1163 |  | -function test_setval!(model, chain; sample_idx=1, chain_idx=1) | 
| 1164 |  | -    var_info = VarInfo(model) | 
| 1165 |  | -    spl = SampleFromPrior() | 
| 1166 |  | -    θ_old = var_info[spl] | 
| 1167 |  | -    DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) | 
| 1168 |  | -    θ_new = var_info[spl] | 
| 1169 |  | -    @test θ_old != θ_new | 
| 1170 |  | -    vals = DynamicPPL.values_as(var_info, OrderedDict) | 
| 1171 |  | -    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) | 
| 1172 |  | -    for (n, v) in mapreduce(collect, vcat, iters) | 
| 1173 |  | -        n = string(n) | 
| 1174 |  | -        if Symbol(n) ∉ keys(chain) | 
| 1175 |  | -            # Assume it's a group | 
| 1176 |  | -            chain_val = vec( | 
| 1177 |  | -                MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] | 
| 1178 |  | -            ) | 
| 1179 |  | -            v_true = vec(v) | 
| 1180 |  | -        else | 
| 1181 |  | -            chain_val = chain[sample_idx, n, chain_idx] | 
| 1182 |  | -            v_true = v | 
| 1183 |  | -        end | 
| 1184 |  | - | 
| 1185 |  | -        @test v_true == chain_val | 
| 1186 |  | -    end | 
| 1187 |  | -end | 
| 1188 |  | - | 
| 1189 |  | -""" | 
| 1190 |  | -    short_varinfo_name(vi::AbstractVarInfo) | 
| 1191 |  | -
 | 
| 1192 |  | -Return string representing a short description of `vi`. | 
| 1193 |  | -""" | 
| 1194 |  | -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = | 
| 1195 |  | -    "threadsafe($(short_varinfo_name(vi.varinfo)))" | 
| 1196 |  | -function short_varinfo_name(vi::TypedVarInfo) | 
| 1197 |  | -    DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" | 
| 1198 |  | -    return "TypedVarInfo" | 
| 1199 |  | -end | 
| 1200 |  | -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" | 
| 1201 |  | -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" | 
| 1202 |  | -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" | 
| 1203 |  | -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" | 
| 1204 |  | -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) | 
| 1205 |  | -    return "SimpleVarInfo{<:VarNamedVector}" | 
| 1206 |  | -end | 
| 1207 |  | - | 
| 1208 |  | -# convenient functions for testing model.jl | 
| 1209 |  | -# function to modify the representation of values based on their length | 
| 1210 |  | -function modify_value_representation(nt::NamedTuple) | 
| 1211 |  | -    modified_nt = NamedTuple() | 
| 1212 |  | -    for (key, value) in zip(keys(nt), values(nt)) | 
| 1213 |  | -        if length(value) == 1  # Scalar value | 
| 1214 |  | -            modified_value = value[1] | 
| 1215 |  | -        else  # Non-scalar value | 
| 1216 |  | -            modified_value = value | 
| 1217 |  | -        end | 
| 1218 |  | -        modified_nt = merge(modified_nt, (key => modified_value,)) | 
| 1219 |  | -    end | 
| 1220 |  | -    return modified_nt | 
| 1221 | 1100 | end | 
| 1222 |  | - | 
| 1223 |  | -end  # module TestExtUtils | 
0 commit comments