Skip to content

Commit e735d8c

Browse files
committed
Implement generalized skyline and arbitrary clock models
1 parent e5731a3 commit e735d8c

34 files changed

+425
-62
lines changed

torchtree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This is the root package of the torchtree framework."""
2+
23
from ._version import __version__
34
from .core.parameter import CatParameter, Parameter, TransformedParameter, ViewParameter
45

torchtree/cli/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The cli package contains modules for creating JSON configuration files
22
through a command-line interface."""
3+
34
from torchtree.cli.plugin_manager import PluginManager
45

56
PLUGIN_MANAGER = PluginManager()

torchtree/cli/advi.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -840,14 +840,19 @@ def create_logger(id_, parameters, arg):
840840
if arg.coalescent:
841841
models.append('coalescent')
842842
if arg.coalescent in COALESCENT_PIECEWISE:
843-
models.append('gmrf')
844843
models.append(
845844
{
846845
'id': arg.coalescent,
847846
'type': 'JointDistributionModel',
848-
'distributions': ['coalescent', 'gmrf'],
847+
'distributions': ['coalescent'],
849848
}
850849
)
850+
if arg.theta_prior == 'eml':
851+
models[-1]['distributions'].append('coalescent.theta.eml')
852+
models.append('coalescent.theta.eml')
853+
else:
854+
models[-1]['distributions'].append('gmrf')
855+
models.append('gmrf')
851856

852857
return {
853858
"id": id_,
@@ -877,14 +882,19 @@ def create_sampler(id_, var_id, parameters, arg):
877882
if arg.coalescent:
878883
models.append('coalescent')
879884
if arg.coalescent in COALESCENT_PIECEWISE:
880-
models.append('gmrf')
881885
models.append(
882886
{
883887
'id': arg.coalescent,
884888
'type': 'JointDistributionModel',
885-
'distributions': ['coalescent', 'gmrf'],
889+
'distributions': ['coalescent'],
886890
}
887891
)
892+
if arg.theta_prior == 'eml':
893+
models[-1]['distributions'].append('coalescent.theta.eml')
894+
models.append('coalescent.theta.eml')
895+
else:
896+
models[-1]['distributions'].append('gmrf')
897+
models.append('gmrf')
888898

889899
return {
890900
"id": id_,
@@ -937,7 +947,7 @@ def build_advi(arg):
937947
if arg.clock is not None and arg.heights == 'ratio':
938948
jacobians_list.append('tree')
939949

940-
if arg.coalescent in COALESCENT_PIECEWISE:
950+
if arg.coalescent in COALESCENT_PIECEWISE and arg.theta_prior != 'eml':
941951
jacobians_list.remove("coalescent.theta")
942952

943953
joint_jacobian = {
@@ -984,13 +994,20 @@ def build_advi(arg):
984994
parameters.extend(
985995
(
986996
f'{branch_model_id}.rates.prior.mean',
987-
f'{branch_model_id}.rates.prior.scale',
997+
f'{branch_model_id}.rates.prior.stdev',
998+
)
999+
)
1000+
elif arg.clock == 'ncln':
1001+
parameters.extend(
1002+
(
1003+
f'{branch_model_id}.location',
1004+
f'{branch_model_id}.scale',
9881005
)
9891006
)
9901007
else:
9911008
parameters.append(f"{branch_model_id}.rate")
9921009

993-
if arg.clock == 'horseshoe' or arg.clock == 'ucln':
1010+
if arg.clock == 'horseshoe' or arg.clock in ('ucln', 'ncln'):
9941011
parameters.append(f'{branch_model_id}.rates')
9951012
else:
9961013
parameters = ['tree.blens']
@@ -999,7 +1016,11 @@ def build_advi(arg):
9991016
if arg.coalescent_integrated is None:
10001017
parameters.append("coalescent.theta")
10011018

1002-
if arg.coalescent in COALESCENT_PIECEWISE and not arg.gmrf_integrated:
1019+
if (
1020+
arg.coalescent in COALESCENT_PIECEWISE
1021+
and not arg.gmrf_integrated
1022+
and arg.theta_prior is None
1023+
):
10031024
parameters.append('gmrf.precision')
10041025
elif arg.coalescent == 'exponential':
10051026
parameters.append('coalescent.growth')

torchtree/cli/argparse_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,19 @@ def str_or_float(arg, choices):
3131
)
3232

3333

34+
def str_or_int(arg):
35+
"""Used by argparse when the argument can be either an integer or a string."""
36+
try:
37+
return int(arg)
38+
except ValueError:
39+
if isinstance(arg, str):
40+
return arg
41+
else:
42+
raise argparse.ArgumentTypeError(
43+
'invalid choice (choose from an integer or a string)'
44+
)
45+
46+
3447
def list_of_float(arg, length):
3548
"""Used by argparse when the argument should be a list of floats."""
3649
values = arg.split(",")

0 commit comments

Comments
 (0)