Skip to content

Commit

Permalink
Remove debugging inserts
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Jul 17, 2024
1 parent 82db5a6 commit 9951383
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 32 deletions.
16 changes: 9 additions & 7 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from . import util
from . import variational

logger = logging.getLogger(__name__)

FORMAT_NAME = "tsdate"
DEFAULT_RESCALING_INTERVALS = 1000
DEFAULT_RESCALING_ITERATIONS = 1
Expand Down Expand Up @@ -983,7 +985,7 @@ def __init__(
ts, Ne, approximate_priors=approx, progress=progress
)
else:
logging.info("Using user-specified priors")
logger.info("Using user-specified priors")
if Ne is not None:
raise ValueError(
"Cannot specify population size if specifying priors "
Expand Down Expand Up @@ -1019,7 +1021,7 @@ def get_modified_ts(self, result, eps):
mutations.parent = np.full(mutations.num_rows, tskit.NULL, dtype=np.int32)
tables.time_units = self.time_units
constr_timing -= time.time()
logging.info(f"Constrained node ages in {abs(constr_timing)} seconds")
logger.info(f"Constrained node ages in {abs(constr_timing)} seconds")
# Add posterior mean and variance to node/mutation metadata
meta_timing = time.time()
self.set_time_metadata(
Expand All @@ -1029,15 +1031,15 @@ def get_modified_ts(self, result, eps):
mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema
)
meta_timing -= time.time()
logging.info(
logger.info(
f"Inserted node and mutation metadata in {abs(meta_timing)} seconds"
)
sort_timing = time.time()
tables.sort()
tables.build_index()
tables.compute_mutation_parents()
sort_timing -= time.time()
logging.info(f"Sorted tree sequence in {abs(sort_timing)} seconds")
logger.info(f"Sorted tree sequence in {abs(sort_timing)} seconds")
return tables.tree_sequence()

def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
Expand All @@ -1050,9 +1052,9 @@ def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
md_iter = ({} for _ in range(table.num_rows))
# For speed, assume we don't need to validate
encoder = table.metadata_schema.encode_row
logging.info(f"Set metadata schema on {table_name}")
logger.info(f"Set metadata schema on {table_name}")
else:
logging.warning(
logger.warning(
f"Could not set metadata on {table_name}: "
"data already exists with no schema"
)
Expand All @@ -1073,7 +1075,7 @@ def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
metadata_array.append(encoder(metadata_dict))
table.packset_metadata(metadata_array)
except tskit.MetadataValidationError as e:
logging.warning(f"Could not set time metadata in {table_name}: {e}")
logger.warning(f"Could not set time metadata in {table_name}: {e}")

def parse_result(self, result, epsilon, extra_posterior_cols=None):
# Construct the tree sequence to return and add other stuff we might want to
Expand Down
3 changes: 1 addition & 2 deletions tsdate/phasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def reallocate_unphased(
i, j = blocks_edges[b]
assert tskit.NULL < i < num_edges and edges_unphased[i]
assert tskit.NULL < j < num_edges and edges_unphased[j]
if np.isnan(mutations_phase[m]): # DEBUG
print("ERR skip nan in phase")
if np.isnan(mutations_phase[m]): # TODO: rare numerical issue
continue
assert 0.0 <= mutations_phase[m] <= 1.0
edges_likelihood[i, 0] += mutations_phase[m]
Expand Down
34 changes: 11 additions & 23 deletions tsdate/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .rescaling import piecewise_scale_posterior
from .util import contains_unary_nodes

logger = logging.getLogger(__name__)

# columns for edge_factors
ROOTWARD = 0 # edge likelihood to parent
Expand Down Expand Up @@ -246,7 +247,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True):
self.sizebiased_likelihoods, _ = count_mutations(ts, size_biased=True)
self.sizebiased_likelihoods[:, 1] *= mutation_rate
count_timing -= time.time()
logging.info(f"Extracted mutations in {abs(count_timing)} seconds")
logger.info(f"Extracted mutations in {abs(count_timing)} seconds")

# count mutations in singleton blocks
phase_timing = time.time()
Expand All @@ -260,9 +261,9 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True):
self.block_nodes[1] = self.edge_parents[self.block_edges[:, 1]]
num_unphased = np.sum(self.mutation_blocks != tskit.NULL)
phase_timing -= time.time()
logging.info(f"Found {num_unphased} unphased singleton mutations")
logging.info(f"Split unphased singleton edges into {num_blocks} blocks")
logging.info(f"Phased singletons in {abs(phase_timing)} seconds")
logger.info(f"Found {num_unphased} unphased singleton mutations")
logger.info(f"Split unphased singleton edges into {num_blocks} blocks")
logger.info(f"Phased singletons in {abs(phase_timing)} seconds")

# mutable
self.node_factors = np.zeros((ts.num_nodes, 2, 2))
Expand Down Expand Up @@ -659,19 +660,6 @@ def fixed_projection(x, y):
child_cavity,
edge_likelihood,
)
# DEBUG: nan in phase vector
if unphased and not np.isfinite(mutations_phase[m]):
print(
"ERR\tm:",
m,
"p:",
parent_cavity,
"c:",
child_cavity,
"e:",
edge_likelihood,
)
# /DEBUG: nan in phase vector

return np.nan

Expand Down Expand Up @@ -843,8 +831,8 @@ def run(
)
nodes_timing -= time.time()
skipped_edges = np.sum(np.isnan(self.edge_logconst))
logging.info(f"Skipped {skipped_edges} edges with invalid factors")
logging.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds")
logger.info(f"Skipped {skipped_edges} edges with invalid factors")
logger.info(f"Calculated node posteriors in {abs(nodes_timing)} seconds")

muts_timing = time.time()
mutations_phased = self.mutation_blocks == tskit.NULL
Expand Down Expand Up @@ -878,8 +866,8 @@ def run(
)
muts_timing -= time.time()
skipped_muts = np.sum(np.isnan(self.mutation_posterior[:, 0]))
logging.info(f"Skipped {skipped_muts} mutations with invalid posteriors")
logging.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds")
logger.info(f"Skipped {skipped_muts} mutations with invalid posteriors")
logger.info(f"Calculated mutation posteriors in {abs(muts_timing)} seconds")

singletons = self.mutation_blocks != tskit.NULL
switched_blocks = self.mutation_blocks[singletons]
Expand All @@ -892,7 +880,7 @@ def run(
self.mutation_nodes[singletons] = self.edge_children[switched_edges]
switched = self.mutation_phase < 0.5
self.mutation_phase[switched] = 1 - self.mutation_phase[switched]
logging.info(f"Switched phase of {np.sum(switched)} singletons")
logger.info(f"Switched phase of {np.sum(switched)} singletons")

if rescale_intervals > 0 and rescale_iterations > 0:
rescale_timing = time.time()
Expand All @@ -906,7 +894,7 @@ def run(
rescale_segsites=rescale_segsites,
)
rescale_timing -= time.time()
logging.info(f"Timescale rescaled in {abs(rescale_timing)} seconds")
logger.info(f"Timescale rescaled in {abs(rescale_timing)} seconds")

def node_moments(self):
"""
Expand Down

0 comments on commit 9951383

Please sign in to comment.