@@ -840,14 +840,19 @@ def create_logger(id_, parameters, arg):
840840 if arg .coalescent :
841841 models .append ('coalescent' )
842842 if arg .coalescent in COALESCENT_PIECEWISE :
843- models .append ('gmrf' )
844843 models .append (
845844 {
846845 'id' : arg .coalescent ,
847846 'type' : 'JointDistributionModel' ,
848- 'distributions' : ['coalescent' , 'gmrf' ],
847+ 'distributions' : ['coalescent' ],
849848 }
850849 )
850+ if arg .theta_prior == 'eml' :
851+ models [- 1 ]['distributions' ].append ('coalescent.theta.eml' )
852+ models .append ('coalescent.theta.eml' )
853+ else :
854+ models [- 1 ]['distributions' ].append ('gmrf' )
855+ models .append ('gmrf' )
851856
852857 return {
853858 "id" : id_ ,
@@ -877,14 +882,19 @@ def create_sampler(id_, var_id, parameters, arg):
877882 if arg .coalescent :
878883 models .append ('coalescent' )
879884 if arg .coalescent in COALESCENT_PIECEWISE :
880- models .append ('gmrf' )
881885 models .append (
882886 {
883887 'id' : arg .coalescent ,
884888 'type' : 'JointDistributionModel' ,
885- 'distributions' : ['coalescent' , 'gmrf' ],
889+ 'distributions' : ['coalescent' ],
886890 }
887891 )
892+ if arg .theta_prior == 'eml' :
893+ models [- 1 ]['distributions' ].append ('coalescent.theta.eml' )
894+ models .append ('coalescent.theta.eml' )
895+ else :
896+ models [- 1 ]['distributions' ].append ('gmrf' )
897+ models .append ('gmrf' )
888898
889899 return {
890900 "id" : id_ ,
@@ -937,7 +947,7 @@ def build_advi(arg):
937947 if arg .clock is not None and arg .heights == 'ratio' :
938948 jacobians_list .append ('tree' )
939949
940- if arg .coalescent in COALESCENT_PIECEWISE :
950+ if arg .coalescent in COALESCENT_PIECEWISE and arg . theta_prior != 'eml' :
941951 jacobians_list .remove ("coalescent.theta" )
942952
943953 joint_jacobian = {
@@ -984,13 +994,20 @@ def build_advi(arg):
984994 parameters .extend (
985995 (
986996 f'{ branch_model_id } .rates.prior.mean' ,
987- f'{ branch_model_id } .rates.prior.scale' ,
997+ f'{ branch_model_id } .rates.prior.stdev' ,
998+ )
999+ )
1000+ elif arg .clock == 'ncln' :
1001+ parameters .extend (
1002+ (
1003+ f'{ branch_model_id } .location' ,
1004+ f'{ branch_model_id } .scale' ,
9881005 )
9891006 )
9901007 else :
9911008 parameters .append (f"{ branch_model_id } .rate" )
9921009
993- if arg .clock == 'horseshoe' or arg .clock == 'ucln' :
1010+ if arg .clock == 'horseshoe' or arg .clock in ( 'ucln' , 'ncln' ) :
9941011 parameters .append (f'{ branch_model_id } .rates' )
9951012 else :
9961013 parameters = ['tree.blens' ]
@@ -999,7 +1016,11 @@ def build_advi(arg):
9991016 if arg .coalescent_integrated is None :
10001017 parameters .append ("coalescent.theta" )
10011018
1002- if arg .coalescent in COALESCENT_PIECEWISE and not arg .gmrf_integrated :
1019+ if (
1020+ arg .coalescent in COALESCENT_PIECEWISE
1021+ and not arg .gmrf_integrated
1022+ and arg .theta_prior is None
1023+ ):
10031024 parameters .append ('gmrf.precision' )
10041025 elif arg .coalescent == 'exponential' :
10051026 parameters .append ('coalescent.growth' )
0 commit comments