Discrepancy Function #317

Draft
wants to merge 27 commits into
base: main
flake fix
hfr1tz3 committed Oct 9, 2023
commit a7c656ec66365f02d23695b2f3c863a5fe64dc40
72 changes: 48 additions & 24 deletions tsdate/evaluation.py
@@ -261,9 +261,9 @@ def node_spans(ts):
not have 'missing data' there).
"""
child_spans = np.bincount(
ts.edges_child,
weights=ts.edges_right - ts.edges_left,
minlength=ts.num_nodes,
ts.edges_child,
weights=ts.edges_right - ts.edges_left,
minlength=ts.num_nodes,
)
for t in ts.trees():
span = t.span
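
The `np.bincount` call in this hunk appears to sum edge spans per child node. A minimal sketch of that step, assuming a made-up edge table (the three arrays and `num_nodes` below are hypothetical stand-ins for `ts.edges_child`, `ts.edges_left`, `ts.edges_right` and `ts.num_nodes`, not data from this PR):

```python
import numpy as np

# Hypothetical edge table: three edges over a genome of length 10.
edges_child = np.array([0, 1, 0])
edges_left = np.array([0.0, 0.0, 4.0])
edges_right = np.array([4.0, 10.0, 10.0])
num_nodes = 4  # nodes 2 and 3 never appear as a child

# Sum the span of every edge, grouped by the edge's child node.
# minlength pads the result so nodes without edges get a span of 0.
child_spans = np.bincount(
    edges_child,
    weights=edges_right - edges_left,
    minlength=num_nodes,
)
print(child_spans)
```

Node 0 is a child on two edges (spans 4 and 6), so its total is 10; nodes that are never a child keep a span of 0, which is why the loop over trees that follows in `node_spans` is still needed for roots.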
@@ -276,61 +276,85 @@ def node_spans(ts):

def tree_discrepancy(ts, other):
"""
For two tree sequences `ts` and `other`, this method `tree_discrepancy` returns two values:
1. the sum across the nodes of `ts` of the length of the discrepancy in span between the node and its best match node in `other` weighted by difference in time.
2. The root mean squared difference between the times in the nodes in `ts` and times of their best matching nodes in `other` with the average weighted by the span in `ts`.
For two tree sequences `ts` and `other`,
`tree_discrepancy` returns two values:
1. The sum, across the nodes of `ts`,
of the discrepancy in span between each node
and its best matching node in `other`,
weighted by the difference in time.
2. The root-mean-squared difference
between the times of the nodes in `ts`
and the times of their best matching nodes in `other`,
with the average weighted by the span in `ts`.

This is done as follows:

For each node in `ts` the best matching node(s) from `other` has the longest matching span using `shared_node_spans`. If either tree sequence contains unary nodes there may be multiple matches with the same longest shared span for a single node. In this case, the best match is the node closest in time. The discrepancy is:
For each node in `ts`, the best matching node(s) in `other`
are those with the longest shared span, as computed by `shared_node_spans`.
If either tree sequence contains unary nodes, there may be
multiple matches with the same longest shared span
for a single node.
In this case, the best match is the node closest in time.
The discrepancy is:
.. math::

d(ts, other) = 1 -
\left(sum_{x\in \operatorname{ts}} \min_{y\in \operatorname{other}} |t_x - t_y| \max{y \in \operatorname{other}} \frac{1}{T}* \operatorname{shared_span}(x,y)\right),
\\left(\\sum_{x \\in \\operatorname{ts}}
\\min_{y \\in \\operatorname{other}}
|t_x - t_y| \\max_{y \\in \\operatorname{other}}
\\frac{1}{T} \\operatorname{shared\\_span}(x,y)\\right),

where :math:`T` is the sum of spans of all nodes in `ts`.

Returns two values:
`discrepancy` (float) the value computed above.
`root-mean-squared discrepancy` (float)
`root-mean-squared discrepancy` (float)
"""

shared_spans = shared_node_spans(ts, other)
# Find all potential matches for a node based on max shared span length
max_span = shared_spans.max(axis=1).toarray().flatten()
col_ind = shared_spans.indices
row_ind = np.repeat(np.arange(shared_spans.shape[0]),
repeats = np.diff(shared_spans.indptr))

row_ind = np.repeat(
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
)

match = shared_spans.data == max_span[row_ind]
# Construct a matrix of potential matches and
# scale with difference in node times
match_matrix = scipy.sparse.coo_matrix(
(shared_spans.data[match], (row_ind[match], col_ind[match])),
shape = (ts.num_nodes, other.num_nodes),
shape=(ts.num_nodes, other.num_nodes),
)
ts_times = ts.nodes_time[row_ind[match]]
other_times = other.nodes_time[col_ind[match]]
time_matrix = scipy.sparse.coo_matrix(
(np.absolute(np.asarray(ts_times-other_times)), (row_ind[match], col_ind[match])),
shape = (ts.num_nodes, other.num_nodes),
(
np.absolute(np.asarray(ts_times - other_times)),
(row_ind[match], col_ind[match]),
),
shape=(ts.num_nodes, other.num_nodes),
)
discrepancy_matrix = match_matrix.multiply(time_matrix).tocsr()
# determine best matches with the following matrix
m = scipy.sparse.csr_matrix(
(1/(1 + discrepancy_matrix.data), (discrepancy_matrix.indices)),
shape = (ts.num_nodes, other.num_nodes)
(
1 / (1 + discrepancy_matrix.data),
discrepancy_matrix.indices,
discrepancy_matrix.indptr,
),
shape=(ts.num_nodes, other.num_nodes),
)
# Between each pair of nodes, find the maximum shared span
best_match = m.argmax(axis=1).A1
best_match_spans = shared_spans[np.linspace(len(best_match)), best_match].reshape(-1)
best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(
-1
)
# Return the discrepancy between ts and other
node_spans = node_span(ts)
total_node_spans = np.sum(node_spans)
discrepancy = 1 - np.sum(best_match_spans)/total_node_spans
ts_node_spans = node_spans(ts)
total_node_spans = np.sum(ts_node_spans)
discrepancy = 1 - np.sum(best_match_spans) / total_node_spans
# Compute the root-mean-square discrepancy in time
# with the average weighted by span in ts
time_discrepancies = time_matrix[np.linspace(len(best_match)), best_match].reshape(-1)
rmse = np.sqrt(np.sum(time_discrepancies**2*node_spans)/total_node_spans)
time_discrepancies = (
time_matrix.tocsr()[np.arange(len(best_match)), best_match].A1
)
rmse = np.sqrt(np.sum(time_discrepancies**2 * ts_node_spans) / total_node_spans)

return discrepancy, rmse
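
The matching steps described in the docstring can be hard to follow in sparse-matrix form. The sketch below restates them with dense NumPy arrays instead of `scipy.sparse`, which is a deliberate simplification and not the code in this PR. The function name `dense_discrepancy` and all of its arguments are hypothetical; `shared_spans` stands in for the output of `shared_node_spans(ts, other)` converted to a dense array.

```python
import numpy as np


def dense_discrepancy(shared_spans, ts_times, other_times, ts_node_spans):
    """Dense-array sketch of the best-match logic (hypothetical helper)."""
    shared_spans = np.asarray(shared_spans, dtype=float)
    ts_times = np.asarray(ts_times, dtype=float)
    other_times = np.asarray(other_times, dtype=float)
    ts_node_spans = np.asarray(ts_node_spans, dtype=float)

    # Candidate matches: entries equal to the row maximum of shared span.
    max_span = shared_spans.max(axis=1, keepdims=True)
    candidate = (shared_spans == max_span) & (shared_spans > 0)

    # Absolute time difference between every (ts node, other node) pair.
    time_diff = np.abs(ts_times[:, None] - other_times[None, :])

    # Among tied candidates, pick the node closest in time.
    # Rows with no candidate at all fall back to column 0.
    masked = np.where(candidate, time_diff, np.inf)
    best_match = masked.argmin(axis=1)

    rows = np.arange(shared_spans.shape[0])
    best_spans = shared_spans[rows, best_match]
    best_time_diff = time_diff[rows, best_match]

    total = ts_node_spans.sum()
    discrepancy = 1 - best_spans.sum() / total
    rmse = np.sqrt(np.sum(best_time_diff**2 * ts_node_spans) / total)
    return discrepancy, rmse


# Two nodes in `ts`, three in `other`; node 1 has a tie in shared span.
d, r = dense_discrepancy(
    shared_spans=[[10, 0, 0], [5, 5, 0]],
    ts_times=[0.0, 2.0],
    other_times=[0.0, 1.0, 3.0],
    ts_node_spans=[10, 10],
)
print(d, r)
```

In this toy input, node 1 shares a span of 5 with two nodes in `other`, and the tie is broken in favour of the one at time 1.0 because it is closer to 2.0. Note the row index is built with `np.arange`; `np.linspace` requires both a start and a stop and does not produce integer indices.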