Skip to content

Commit e8a3e0d

Browse files
Finish up implementation of samples arg
1 parent 2895f65 commit e8a3e0d

File tree

6 files changed

+109
-28
lines changed

6 files changed

+109
-28
lines changed

c/tests/test_trees.c

+13-2
Original file line numberDiff line numberDiff line change
@@ -3808,11 +3808,22 @@ test_simplest_divergence_matrix(void)
38083808

38093809
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
38103810

3811-
/* ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, D); */
38123811
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
38133812
CU_ASSERT_EQUAL_FATAL(ret, 0);
38143813
assert_arrays_almost_equal(4, D, result);
38153814

3815+
ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, D);
3816+
CU_ASSERT_EQUAL_FATAL(ret, 0);
3817+
assert_arrays_almost_equal(4, D, result);
3818+
3819+
sample_ids[0] = -1;
3820+
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
3821+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
3822+
3823+
sample_ids[0] = 3;
3824+
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
3825+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
3826+
38163827
tsk_treeseq_free(&ts);
38173828
}
38183829

@@ -3861,7 +3872,7 @@ test_simplest_divergence_matrix_internal_sample(void)
38613872
{
38623873
const char *nodes = "1 0 0\n"
38633874
"1 0 0\n"
3864-
"1 1 0\n";
3875+
"0 1 0\n";
38653876
const char *edges = "0 1 2 0,1\n";
38663877
tsk_treeseq_t ts;
38673878
tsk_id_t sample_ids[] = { 0, 1, 2 };

c/tskit/trees.c

+26-1
Original file line numberDiff line numberDiff line change
@@ -6252,14 +6252,34 @@ sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y)
62526252
return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1;
62536253
}
62546254

6255+
static int
6256+
tsk_treeseq_check_node_bounds(
6257+
const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes)
6258+
{
6259+
int ret = 0;
6260+
tsk_size_t j;
6261+
tsk_id_t u;
6262+
const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows;
6263+
6264+
for (j = 0; j < num_nodes; j++) {
6265+
u = nodes[j];
6266+
if (u < 0 || u >= N) {
6267+
ret = TSK_ERR_NODE_OUT_OF_BOUNDS;
6268+
goto out;
6269+
}
6270+
}
6271+
out:
6272+
return ret;
6273+
}
6274+
62556275
int
62566276
tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
62576277
const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows,
62586278
tsk_flags_t TSK_UNUSED(options), double *result)
62596279
{
62606280
int ret = 0;
62616281
tsk_tree_t tree;
6262-
const tsk_id_t *samples = self->samples;
6282+
const tsk_id_t *restrict samples = self->samples;
62636283
const double default_windows[] = { 0, self->tables->sequence_length };
62646284
const double *restrict nodes_time = self->tables->nodes.time;
62656285
tsk_size_t n = self->num_samples;
@@ -6292,7 +6312,12 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
62926312
if (samples_in != NULL) {
62936313
samples = samples_in;
62946314
n = num_samples;
6315+
ret = tsk_treeseq_check_node_bounds(self, n, samples);
6316+
if (ret != 0) {
6317+
goto out;
6318+
}
62956319
}
6320+
62966321
memset(result, 0, num_windows * n * n * sizeof(*result));
62976322

62986323
for (i = 0; i < num_windows; i++) {

python/_tskitmodule.c

+21-5
Original file line numberDiff line numberDiff line change
@@ -9641,25 +9641,38 @@ static PyObject *
96419641
TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
96429642
{
96439643
PyObject *ret = NULL;
9644-
static char *kwlist[] = { "windows", NULL };
9644+
static char *kwlist[] = { "windows", "samples", NULL };
96459645
PyArrayObject *result_array = NULL;
96469646
PyObject *windows = NULL;
9647+
PyObject *py_samples = Py_None;
96479648
PyArrayObject *windows_array = NULL;
9649+
PyArrayObject *samples_array = NULL;
96489650
tsk_flags_t options = 0;
9649-
npy_intp dims[3];
9651+
npy_intp *shape, dims[3];
96509652
tsk_size_t num_samples, num_windows;
9653+
tsk_id_t *samples = NULL;
96519654
int err;
96529655

96539656
if (TreeSequence_check_state(self) != 0) {
96549657
goto out;
96559658
}
9656-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &windows)) {
9659+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &windows, &py_samples)) {
96579660
goto out;
96589661
}
9662+
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
9663+
if (py_samples != Py_None) {
9664+
samples_array = (PyArrayObject *) PyArray_FROMANY(
9665+
py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
9666+
if (samples_array == NULL) {
9667+
goto out;
9668+
}
9669+
shape = PyArray_DIMS(samples_array);
9670+
samples = PyArray_DATA(samples_array);
9671+
num_samples = (tsk_size_t) shape[0];
9672+
}
96599673
if (parse_windows(windows, &windows_array, &num_windows) != 0) {
96609674
goto out;
96619675
}
9662-
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
96639676
dims[0] = num_windows;
96649677
dims[1] = num_samples;
96659678
dims[2] = num_samples;
@@ -9670,7 +9683,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
96709683
// clang-format off
96719684
Py_BEGIN_ALLOW_THREADS
96729685
err = tsk_treeseq_divergence_matrix(
9673-
self->tree_sequence, 0, NULL,
9686+
self->tree_sequence,
9687+
num_samples, samples,
96749688
num_windows, PyArray_DATA(windows_array),
96759689
options, PyArray_DATA(result_array));
96769690
Py_END_ALLOW_THREADS
@@ -9685,6 +9699,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
96859699
result_array = NULL;
96869700
out:
96879701
Py_XDECREF(result_array);
9702+
Py_XDECREF(windows_array);
9703+
/* Py_XDECREF(samples_array); */
96889704
return ret;
96899705
}
96909706

python/tests/test_divmat.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,18 @@ def check_divmat(
260260

261261
D1 = divergence_matrix(ts, windows=windows, samples=samples)
262262
if compare_stats_api:
263+
# Some things, like duplicate samples, aren't worth hacking around for in the
264+
# stats API.
263265
D2 = lib_divergence_matrix(ts, windows=windows, samples=samples)
264266
# print("windows = ", windows)
265267
# print(D1)
266268
# print(D2)
267269
np.testing.assert_allclose(D1, D2)
268-
# D3 = ts.divergence_matrix(windows=windows)
269-
# # print(D3)
270-
# np.testing.assert_allclose(D1, D3)
270+
assert D1.shape == D2.shape
271+
D3 = ts.divergence_matrix(windows=windows, samples=samples)
272+
# print(D3)
273+
assert D1.shape == D3.shape
274+
np.testing.assert_allclose(D1, D3)
271275
return D1
272276

273277

@@ -579,9 +583,9 @@ def test_disconnected_non_sample_topology(self):
579583

580584

581585
class TestThreadsNoWindows:
582-
def check(self, ts, num_threads):
583-
D1 = ts.divergence_matrix(num_threads=0)
584-
D2 = ts.divergence_matrix(num_threads=num_threads)
586+
def check(self, ts, num_threads, samples=None):
587+
D1 = ts.divergence_matrix(num_threads=0, samples=samples)
588+
D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples)
585589
np.testing.assert_array_almost_equal(D1, D2)
586590

587591
@pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27])
@@ -590,6 +594,12 @@ def test_all_trees(self, num_threads):
590594
assert ts.num_trees == 26
591595
self.check(ts, num_threads)
592596

597+
@pytest.mark.parametrize("samples", [None, [0, 1]])
598+
def test_all_trees_samples(self, samples):
599+
ts = tsutil.all_trees_ts(4)
600+
assert ts.num_trees == 26
601+
self.check(ts, 2, samples)
602+
593603
@pytest.mark.parametrize("n", [2, 3, 5, 15])
594604
@pytest.mark.parametrize("num_threads", range(1, 5))
595605
def test_simple_sims(self, n, num_threads):
@@ -606,9 +616,11 @@ def test_simple_sims(self, n, num_threads):
606616

607617

608618
class TestThreadsWindows:
609-
def check(self, ts, num_threads, *, windows):
610-
D1 = ts.divergence_matrix(num_threads=0, windows=windows)
611-
D2 = ts.divergence_matrix(num_threads=num_threads, windows=windows)
619+
def check(self, ts, num_threads, *, windows, samples=None):
620+
D1 = ts.divergence_matrix(num_threads=0, windows=windows, samples=samples)
621+
D2 = ts.divergence_matrix(
622+
num_threads=num_threads, windows=windows, samples=samples
623+
)
612624
np.testing.assert_array_almost_equal(D1, D2)
613625

614626
@pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27])
@@ -628,6 +640,18 @@ def test_all_trees(self, num_threads, windows):
628640
assert ts.num_trees == 26
629641
self.check(ts, num_threads, windows=windows)
630642

643+
@pytest.mark.parametrize("samples", [None, [0, 1]])
644+
@pytest.mark.parametrize(
645+
["windows"],
646+
[
647+
([0, 26],),
648+
(None,),
649+
],
650+
)
651+
def test_all_trees_samples(self, samples, windows):
652+
ts = tsutil.all_trees_ts(4)
653+
self.check(ts, 2, windows=windows, samples=samples)
654+
631655
@pytest.mark.parametrize("num_threads", range(1, 5))
632656
@pytest.mark.parametrize(
633657
["windows"],

python/tests/test_lowlevel.py

+4
Original file line numberDiff line numberDiff line change
@@ -1534,12 +1534,16 @@ def test_divergence_matrix(self):
15341534
ts = self.get_example_tree_sequence(n, random_seed=12)
15351535
D = ts.divergence_matrix([0, ts.get_sequence_length()])
15361536
assert D.shape == (1, n, n)
1537+
D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1])
1538+
assert D.shape == (1, 2, 2)
15371539
with pytest.raises(TypeError):
15381540
ts.divergence_matrix(windoze=[0, 1])
15391541
with pytest.raises(ValueError, match="at least 2"):
15401542
ts.divergence_matrix(windows=[0])
15411543
with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"):
15421544
ts.divergence_matrix(windows=[-1, 0, 1])
1545+
with pytest.raises(ValueError):
1546+
ts.divergence_matrix(windows=[0, 1], samples="sdf")
15431547

15441548
def test_load_tables_build_indexes(self):
15451549
for ts in self.get_example_tree_sequences():

python/tskit/trees.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -7770,24 +7770,24 @@ def _chunk_windows(windows, num_chunks):
77707770
# k += 1
77717771
# return A
77727772

7773-
# NOTE see older definition above that we didn't finish up. Are there things
7774-
# we should take from this?
7773+
# NOTE see older definition of divmat above that we didn't finish up.
7774+
# Are there things we should take from this?
77757775

7776-
def _parallelise_divmat_by_tree(self, num_threads):
7776+
def _parallelise_divmat_by_tree(self, num_threads, samples):
77777777
"""
77787778
No windows were specified, so we can chunk up the whole genome by
77797779
tree, and do a simple sum of the results.
77807780
"""
77817781

77827782
def worker(interval):
7783-
return self._ll_tree_sequence.divergence_matrix(interval)
7783+
return self._ll_tree_sequence.divergence_matrix(interval, samples=samples)
77847784

77857785
work = self._chunk_sequence_by_tree(num_threads)
77867786
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
77877787
results = pool.map(worker, work)
77887788
return sum(results)
77897789

7790-
def _parallelise_divmat_by_window(self, windows, num_threads):
7790+
def _parallelise_divmat_by_window(self, windows, num_threads, samples):
77917791
"""
77927792
We assume we have a number of windows that's >= to the number
77937793
of threads available, and let each thread have a chunk of the
@@ -7797,28 +7797,29 @@ def _parallelise_divmat_by_window(self, windows, num_threads):
77977797
"""
77987798

77997799
def worker(sub_windows):
7800-
return self._ll_tree_sequence.divergence_matrix(sub_windows)
7800+
return self._ll_tree_sequence.divergence_matrix(
7801+
sub_windows, samples=samples
7802+
)
78017803

78027804
work = self._chunk_windows(windows, num_threads)
78037805
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
78047806
futures = [executor.submit(worker, sub_windows) for sub_windows in work]
78057807
concurrent.futures.wait(futures)
78067808
return np.vstack([future.result() for future in futures])
78077809

7808-
def divergence_matrix(self, *, windows=None, num_threads=0):
7809-
# TODO implement "samples" argument
7810+
def divergence_matrix(self, *, windows=None, samples=None, num_threads=0):
78107811
windows_specified = windows is not None
78117812
windows = [0, self.sequence_length] if windows is None else windows
78127813

78137814
# NOTE: maybe we want to use a different default for num_threads here, just
78147815
# following the approach in GNN
78157816
if num_threads <= 0:
7816-
D = self._ll_tree_sequence.divergence_matrix(windows)
7817+
D = self._ll_tree_sequence.divergence_matrix(windows, samples=samples)
78177818
else:
78187819
if windows_specified:
7819-
D = self._parallelise_divmat_by_window(windows, num_threads)
7820+
D = self._parallelise_divmat_by_window(windows, num_threads, samples)
78207821
else:
7821-
D = self._parallelise_divmat_by_tree(num_threads)
7822+
D = self._parallelise_divmat_by_tree(num_threads, samples)
78227823

78237824
if not windows_specified:
78247825
# Drop the windows dimension

0 commit comments

Comments
 (0)