Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions _tsinfermodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ AncestorBuilder_add_site(AncestorBuilder *self, PyObject *args, PyObject *kwds)
&site_id, &age, &PyArray_Type, &genotypes)) {
goto out;
}
genotypes_array = (PyArrayObject *) PyArray_FROM_OTF(genotypes, NPY_UINT8,
genotypes_array = (PyArrayObject *) PyArray_FROM_OTF(genotypes, NPY_INT8,
NPY_ARRAY_IN_ARRAY);
if (genotypes_array == NULL) {
goto out;
Expand Down Expand Up @@ -188,7 +188,7 @@ AncestorBuilder_make_ancestor(AncestorBuilder *self, PyObject *args, PyObject *k
PyErr_SetString(PyExc_ValueError, "num_focal_sites must > 0 and <= num_sites");
goto fail;
}
ancestor_array = (PyArrayObject *) PyArray_FROM_OTF(ancestor, NPY_UINT8,
ancestor_array = (PyArrayObject *) PyArray_FROM_OTF(ancestor, NPY_INT8,
NPY_ARRAY_INOUT_ARRAY);
if (ancestor_array == NULL) {
goto fail;
Expand Down Expand Up @@ -611,7 +611,7 @@ TreeSequenceBuilder_add_mutations(TreeSequenceBuilder *self, PyObject *args, PyO
num_mutations = shape[0];

/* derived_state */
derived_state_array = (PyArrayObject *) PyArray_FROM_OTF(derived_state, NPY_UINT8,
derived_state_array = (PyArrayObject *) PyArray_FROM_OTF(derived_state, NPY_INT8,
NPY_ARRAY_IN_ARRAY);
if (derived_state_array == NULL) {
goto out;
Expand Down Expand Up @@ -1308,7 +1308,7 @@ AncestorMatcher_find_path(AncestorMatcher *self, PyObject *args, PyObject *kwds)
&haplotype, &start, &end, &PyArray_Type, &match)) {
goto out;
}
haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_UINT8,
haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_INT8,
NPY_ARRAY_IN_ARRAY);
if (haplotype_array == NULL) {
goto out;
Expand All @@ -1323,7 +1323,7 @@ AncestorMatcher_find_path(AncestorMatcher *self, PyObject *args, PyObject *kwds)
goto out;
}

match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_UINT8,
match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_INT8,
NPY_ARRAY_INOUT_ARRAY);
if (match_array == NULL) {
goto out;
Expand Down
18 changes: 8 additions & 10 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import numpy as np

import tsinfer
import tsinfer.constants as constants


def get_smc_simulation(n, L=1, recombination_rate=0, seed=1):
Expand Down Expand Up @@ -300,8 +299,7 @@ def get_matrix(self, ts):
"""
Simple implementation using tree traversals.
"""
A = np.zeros((ts.num_nodes, ts.num_sites), dtype=np.uint8)
A[:] = constants.UNKNOWN_ALLELE
A = np.full((ts.num_nodes, ts.num_sites), tskit.MISSING_DATA, dtype=np.int8)
for t in ts.trees():
for site in t.sites():
for u in t.nodes():
Expand Down Expand Up @@ -338,7 +336,7 @@ def verify_haplotypes(self, ts, A):
self.assertTrue(np.all(A[above, site.id] == 0))
outside = np.array(list(
set(range(ts.num_nodes)) - set(tree.nodes())), dtype=int)
self.assertTrue(np.all(A[outside, site.id] == constants.UNKNOWN_ALLELE))
self.assertTrue(np.all(A[outside, site.id] == tskit.MISSING_DATA))

def test_single_tree(self):
ts = msprime.simulate(5, mutation_rate=10, random_seed=234)
Expand Down Expand Up @@ -423,9 +421,9 @@ def verify_many_trees_dense_mutations(self, ts):
self.assertTrue(np.all(ancestors[0, :] == 0))
for a, s, e, focal in zip(ancestors[1:], start[1:], end[1:], focal_sites[1:]):
self.assertTrue(0 <= s < e <= m)
self.assertTrue(np.all(a[:s] == constants.UNKNOWN_ALLELE))
self.assertTrue(np.all(a[e:] == constants.UNKNOWN_ALLELE))
self.assertTrue(np.all(a[s:e] != constants.UNKNOWN_ALLELE))
self.assertTrue(np.all(a[:s] == tskit.MISSING_DATA))
self.assertTrue(np.all(a[e:] == tskit.MISSING_DATA))
self.assertTrue(np.all(a[s:e] != tskit.MISSING_DATA))
for site in focal:
self.assertEqual(a[site], 1)

Expand Down Expand Up @@ -906,7 +904,7 @@ def test_inferred_random_data(self):
np.random.seed(10)
num_sites = 40
num_samples = 8
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.int8)
with tsinfer.SampleData() as sample_data:
for j in range(num_sites):
sample_data.add_site(j, G[j])
Expand Down Expand Up @@ -978,7 +976,7 @@ def two_populations_high_migration_example(self, mutation_rate=10):

def get_random_data_example(self, num_sites, num_samples, seed=100):
np.random.seed(seed)
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.int8)
with tsinfer.SampleData() as sample_data:
for j in range(num_sites):
sample_data.add_site(j, G[j])
Expand Down Expand Up @@ -1126,7 +1124,7 @@ def test_many_trees(self):

def get_random_data_example(self, position, num_samples, seed=100):
np.random.seed(seed)
G = np.random.randint(2, size=(position.shape[0], num_samples)).astype(np.uint8)
G = np.random.randint(2, size=(position.shape[0], num_samples)).astype(np.int8)
with tsinfer.SampleData() as sample_data:
for j, x in enumerate(position):
sample_data.add_site(x, G[j])
Expand Down
23 changes: 12 additions & 11 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import numcodecs
import numcodecs.blosc as blosc
import zarr
import tskit

import tsinfer
import tsinfer.formats as formats
Expand Down Expand Up @@ -89,7 +90,7 @@ def verify_data_round_trip(self, ts, input_file):
self.assertEqual(input_file.num_samples, ts.num_samples)
self.assertEqual(input_file.sequence_length, ts.sequence_length)
self.assertEqual(input_file.num_sites, ts.num_sites)
self.assertEqual(input_file.sites_genotypes.dtype, np.uint8)
self.assertEqual(input_file.sites_genotypes.dtype, np.int8)
self.assertEqual(input_file.sites_position.dtype, np.float64)
# Take copies to avoid decompressing the data repeatedly.
genotypes = input_file.sites_genotypes[:]
Expand Down Expand Up @@ -240,14 +241,14 @@ def test_provenance(self):

def test_variant_errors(self):
input_file = formats.SampleData(sequence_length=10)
genotypes = np.zeros(2, np.uint8)
genotypes = np.zeros(2, np.int8)
input_file.add_site(0, alleles=["0", "1"], genotypes=genotypes)
for bad_position in [-1, 10, 100]:
self.assertRaises(
ValueError, input_file.add_site, position=bad_position,
alleles=["0", "1"], genotypes=genotypes)
for bad_genotypes in [[0, 2], [-1, 0], [], [0], [0, 0, 0]]:
genotypes = np.array(bad_genotypes, dtype=np.uint8)
genotypes = np.array(bad_genotypes, dtype=np.int8)
self.assertRaises(
ValueError, input_file.add_site, position=1,
alleles=["0", "1"], genotypes=genotypes)
Expand Down Expand Up @@ -742,15 +743,15 @@ def get_example_data(self, sample_size, sequence_length, num_ancestors):
num_sites = sample_data.num_inference_sites
ancestors = []
for j in range(num_ancestors):
haplotype = np.zeros(num_sites, dtype=np.uint8) + tsinfer.UNKNOWN_ALLELE
haplotype = np.full(num_sites, tskit.MISSING_DATA, dtype=np.int8)
start = j
end = max(num_sites - j, start + 1)
self.assertLess(start, end)
haplotype[start: end] = 0
if start + j < end:
haplotype[start + j: end] = 1
self.assertTrue(np.all(haplotype[:start] == tsinfer.UNKNOWN_ALLELE))
self.assertTrue(np.all(haplotype[end:] == tsinfer.UNKNOWN_ALLELE))
self.assertTrue(np.all(haplotype[:start] == tskit.MISSING_DATA))
self.assertTrue(np.all(haplotype[end:] == tskit.MISSING_DATA))
focal_sites = np.array([start + k for k in range(j)], dtype=np.int32)
focal_sites = focal_sites[focal_sites < end]
haplotype[focal_sites] = 1
Expand Down Expand Up @@ -912,26 +913,26 @@ def test_add_ancestor_errors(self):
self.assertRaises(
ValueError, ancestor_data.add_ancestor,
start=0, end=num_sites, age=1, focal_sites=[],
haplotype=np.zeros(num_sites + 1, dtype=np.uint8))
haplotype=np.zeros(num_sites + 1, dtype=np.int8))
# Haplotypes must be < 2
self.assertRaises(
ValueError, ancestor_data.add_ancestor,
start=0, end=num_sites, age=1, focal_sites=[],
haplotype=np.zeros(num_sites, dtype=np.uint8) + 2)
haplotype=np.full(num_sites, 2, dtype=np.int8))
# focal sites must be within start:end
self.assertRaises(
ValueError, ancestor_data.add_ancestor,
start=1, end=num_sites, age=1, focal_sites=[0],
haplotype=np.ones(num_sites - 1, dtype=np.uint8))
haplotype=np.ones(num_sites - 1, dtype=np.int8))
self.assertRaises(
ValueError, ancestor_data.add_ancestor,
start=0, end=num_sites - 2, age=1, focal_sites=[num_sites - 1],
haplotype=np.ones(num_sites, dtype=np.uint8))
haplotype=np.ones(num_sites, dtype=np.int8))
# focal sites must be set to 1
self.assertRaises(
ValueError, ancestor_data.add_ancestor,
start=0, end=num_sites, age=1, focal_sites=[0],
haplotype=np.zeros(num_sites, dtype=np.uint8))
haplotype=np.zeros(num_sites, dtype=np.int8))

@unittest.skipIf(sys.platform == "win32",
"windows simultaneous file permissions issue")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def verify_ancestors(self, sample_data, ancestor_data):
self.assertEqual(a.shape[0], end[j] - start[j])
h = np.zeros(ancestor_data.num_sites, dtype=np.uint8)
h[start[j]: end[j]] = a
self.assertTrue(np.all(h[start[j]:end[j]] != tsinfer.UNKNOWN_ALLELE))
self.assertTrue(np.all(h[start[j]:end[j]] != tskit.MISSING_DATA))
self.assertTrue(np.all(h[focal_sites[j]] == 1))
used_sites.extend(focal_sites[j])
self.assertGreater(age[j], 0)
Expand Down
30 changes: 30 additions & 0 deletions tsinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,36 @@
"""

import sys
# Start temporary monkey patch to allow use of functions / constants in tskit v2.0.
# The following lines can be deleted once the master tskit version has been updated
import numpy as np
import tskit
tskit.MISSING_DATA = -1


class util():
@staticmethod
def safe_np_int_cast(int_array, dtype, copy=False): # Copied from v2.0 tskit/util.py
if not isinstance(int_array, np.ndarray):
int_array = np.array(int_array)
copy = False
if int_array.size == 0:
return int_array.astype(dtype, copy=copy)
try:
return int_array.astype(dtype, casting='safe', copy=copy)
except TypeError:
bounds = np.iinfo(dtype)
if np.any(int_array < bounds.min) or np.any(int_array > bounds.max):
raise OverflowError("Cannot convert safely to {} type".format(dtype))
if int_array.dtype.kind == 'i' and np.dtype(dtype).kind == 'u':
casting = 'unsafe'
else:
casting = 'same_kind'
return int_array.astype(dtype, casting=casting, copy=copy)


tskit.util = util
# End temporary monkey patch

if sys.version_info[0] < 3:
raise Exception("Python 3 only")
Expand Down
35 changes: 19 additions & 16 deletions tsinfer/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import collections

import numpy as np
import tskit
import sortedcontainers
import tskit

import tsinfer.constants as constants

Expand Down Expand Up @@ -185,7 +185,7 @@ def make_ancestor(self, focal_sites, a):
# check all focal sites in this ancestor are at the same age
assert all([self.sites[fs].age == focal_age for fs in focal_sites])

a[:] = constants.UNKNOWN_ALLELE
a[:] = tskit.MISSING_DATA
for focal_site in focal_sites:
a[focal_site] = 1
S = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
Expand All @@ -203,29 +203,29 @@ def make_ancestor(self, focal_sites, a):
focal_site = focal_sites[-1]
last_site = self.compute_ancestral_states(
a, focal_site, range(focal_site + 1, self.num_sites))
assert a[last_site] != constants.UNKNOWN_ALLELE
assert a[last_site] != tskit.MISSING_DATA
end = last_site + 1
# Go leftwards
focal_site = focal_sites[0]
last_site = self.compute_ancestral_states(
a, focal_site, range(focal_site - 1, -1, -1))
assert a[last_site] != constants.UNKNOWN_ALLELE
assert a[last_site] != tskit.MISSING_DATA
start = last_site
return start, end

# Version with 1 focal site
# assert len(focal_sites) == 1
# focal_site = focal_sites[0]
# a[:] = constants.UNKNOWN_ALLELE
# a[:] = tskit.MISSING_DATA
# a[focal_site] = 1

# last_site = self.compute_ancestral_states(
# a, focal_site, range(focal_site + 1, self.num_sites))
# assert a[last_site] != constants.UNKNOWN_ALLELE
# assert a[last_site] != tskit.MISSING_DATA
# end = last_site + 1
# last_site = self.compute_ancestral_states(
# a, focal_site, range(focal_site - 1, -1, -1))
# assert a[last_site] != constants.UNKNOWN_ALLELE
# assert a[last_site] != tskit.MISSING_DATA
# start = last_site
# return start, end

Expand Down Expand Up @@ -582,7 +582,7 @@ def dump_mutations(self):
site[j] = l
node[j] = u
derived_state[j] = d
parent[j] = -1
parent[j] = tskit.NULL
if d == 0:
parent[j] = p
j += 1
Expand All @@ -595,7 +595,7 @@ def is_descendant(pi, u, v):
v is on the path to root from u.
"""
ret = False
if v != -1:
if v != tskit.NULL:
w = u
path = []
while w != v and w != tskit.NULL:
Expand Down Expand Up @@ -639,6 +639,9 @@ def print_state(self):
for l in range(self.num_sites):
print(l, self.max_likelihood_node[l], self.traceback[l], sep="\t")

def is_root(self, u):
return self.parent[u] == tskit.NULL

def check_likelihoods(self):
assert len(set(self.likelihood_nodes)) == len(self.likelihood_nodes)
# Every value in L_nodes must be positive.
Expand All @@ -649,7 +652,7 @@ def check_likelihoods(self):
if v >= 0:
assert u in self.likelihood_nodes
# Roots other than 0 should have v == -2
if u != 0 and self.parent[u] == -1 and self.left_child[u] == -1:
if u != 0 and self.is_root(u) and self.left_child[u] == -1:
# print("root: u = ", u, self.parent[u], self.left_child[u])
assert v == -2

Expand Down Expand Up @@ -727,8 +730,8 @@ def compress_likelihoods(self):
for u in old_likelihood_nodes:
# We need to find the likelihood of the parent of u. If this is
# the same as u, we can delete it.
p = self.parent[u]
if p != -1:
if not self.is_root(u):
p = self.parent[u]
cached_paths.append(p)
v = p
while self.likelihood[v] == -1 and L_cache[v] == -1:
Expand Down Expand Up @@ -788,7 +791,7 @@ def insert_edge(self, edge):
self.right_child[p] = c

def is_nonzero_root(self, u):
return u != 0 and self.parent[u] == -1 and self.left_child[u] == -1
return u != 0 and self.is_root(u) and self.left_child[u] == -1

def find_path(self, h, start, end, match):
Il = self.tree_sequence_builder.left_index
Expand Down Expand Up @@ -833,7 +836,7 @@ def find_path(self, h, start, end, match):
assert left < right

for u in range(n):
if self.parent[u] != -1:
if not self.is_root(u):
self.likelihood[u] = -1

last_root = 0
Expand Down Expand Up @@ -940,8 +943,8 @@ def run_traceback(self, start, end, match):
k = M - 1
# Construct the matched haplotype
match[:] = 0
match[:start] = constants.UNKNOWN_ALLELE
match[end:] = constants.UNKNOWN_ALLELE
match[:start] = tskit.MISSING_DATA
match[end:] = tskit.MISSING_DATA
# Reset the tree.
self.parent[:] = -1
self.left_child[:] = -1
Expand Down
4 changes: 1 addition & 3 deletions tsinfer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
# along with tsinfer. If not, see <http://www.gnu.org/licenses/>.
#
"""
Collection of constants used in tsinfer.
Collection of constants used in tsinfer. We also make use of constants defined in tskit.
"""

UNKNOWN_ALLELE = 255

C_ENGINE = "C"
PY_ENGINE = "P"

Expand Down
Loading