Skip to content

Commit 47a0153

Browse files
committed
Switch to (signed) int8 for haplotype storage
1 parent 045c26c commit 47a0153

File tree

9 files changed

+72
-73
lines changed

9 files changed

+72
-73
lines changed

_tsinfermodule.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ AncestorBuilder_add_site(AncestorBuilder *self, PyObject *args, PyObject *kwds)
123123
&site_id, &age, &PyArray_Type, &genotypes)) {
124124
goto out;
125125
}
126-
genotypes_array = (PyArrayObject *) PyArray_FROM_OTF(genotypes, NPY_UINT8,
126+
genotypes_array = (PyArrayObject *) PyArray_FROM_OTF(genotypes, NPY_INT8,
127127
NPY_ARRAY_IN_ARRAY);
128128
if (genotypes_array == NULL) {
129129
goto out;
@@ -188,7 +188,7 @@ AncestorBuilder_make_ancestor(AncestorBuilder *self, PyObject *args, PyObject *k
188188
PyErr_SetString(PyExc_ValueError, "num_focal_sites must > 0 and <= num_sites");
189189
goto fail;
190190
}
191-
ancestor_array = (PyArrayObject *) PyArray_FROM_OTF(ancestor, NPY_UINT8,
191+
ancestor_array = (PyArrayObject *) PyArray_FROM_OTF(ancestor, NPY_INT8,
192192
NPY_ARRAY_INOUT_ARRAY);
193193
if (ancestor_array == NULL) {
194194
goto fail;
@@ -611,7 +611,7 @@ TreeSequenceBuilder_add_mutations(TreeSequenceBuilder *self, PyObject *args, PyO
611611
num_mutations = shape[0];
612612

613613
/* derived_state */
614-
derived_state_array = (PyArrayObject *) PyArray_FROM_OTF(derived_state, NPY_UINT8,
614+
derived_state_array = (PyArrayObject *) PyArray_FROM_OTF(derived_state, NPY_INT8,
615615
NPY_ARRAY_IN_ARRAY);
616616
if (derived_state_array == NULL) {
617617
goto out;
@@ -1308,7 +1308,7 @@ AncestorMatcher_find_path(AncestorMatcher *self, PyObject *args, PyObject *kwds)
13081308
&haplotype, &start, &end, &PyArray_Type, &match)) {
13091309
goto out;
13101310
}
1311-
haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_UINT8,
1311+
haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_INT8,
13121312
NPY_ARRAY_IN_ARRAY);
13131313
if (haplotype_array == NULL) {
13141314
goto out;
@@ -1323,7 +1323,7 @@ AncestorMatcher_find_path(AncestorMatcher *self, PyObject *args, PyObject *kwds)
13231323
goto out;
13241324
}
13251325

1326-
match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_UINT8,
1326+
match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_INT8,
13271327
NPY_ARRAY_INOUT_ARRAY);
13281328
if (match_array == NULL) {
13291329
goto out;

tests/test_evaluation.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import numpy as np
3131

3232
import tsinfer
33-
import tsinfer.constants as constants
3433

3534

3635
def get_smc_simulation(n, L=1, recombination_rate=0, seed=1):
@@ -300,8 +299,7 @@ def get_matrix(self, ts):
300299
"""
301300
Simple implementation using tree traversals.
302301
"""
303-
A = np.zeros((ts.num_nodes, ts.num_sites), dtype=np.uint8)
304-
A[:] = constants.UNKNOWN_ALLELE
302+
A = np.full((ts.num_nodes, ts.num_sites), tskit.MISSING_DATA, dtype=np.int8)
305303
for t in ts.trees():
306304
for site in t.sites():
307305
for u in t.nodes():
@@ -338,7 +336,7 @@ def verify_haplotypes(self, ts, A):
338336
self.assertTrue(np.all(A[above, site.id] == 0))
339337
outside = np.array(list(
340338
set(range(ts.num_nodes)) - set(tree.nodes())), dtype=int)
341-
self.assertTrue(np.all(A[outside, site.id] == constants.UNKNOWN_ALLELE))
339+
self.assertTrue(np.all(A[outside, site.id] == tskit.MISSING_DATA))
342340

343341
def test_single_tree(self):
344342
ts = msprime.simulate(5, mutation_rate=10, random_seed=234)
@@ -423,9 +421,9 @@ def verify_many_trees_dense_mutations(self, ts):
423421
self.assertTrue(np.all(ancestors[0, :] == 0))
424422
for a, s, e, focal in zip(ancestors[1:], start[1:], end[1:], focal_sites[1:]):
425423
self.assertTrue(0 <= s < e <= m)
426-
self.assertTrue(np.all(a[:s] == constants.UNKNOWN_ALLELE))
427-
self.assertTrue(np.all(a[e:] == constants.UNKNOWN_ALLELE))
428-
self.assertTrue(np.all(a[s:e] != constants.UNKNOWN_ALLELE))
424+
self.assertTrue(np.all(a[:s] == tskit.MISSING_DATA))
425+
self.assertTrue(np.all(a[e:] == tskit.MISSING_DATA))
426+
self.assertTrue(np.all(a[s:e] != tskit.MISSING_DATA))
429427
for site in focal:
430428
self.assertEqual(a[site], 1)
431429

@@ -906,7 +904,7 @@ def test_inferred_random_data(self):
906904
np.random.seed(10)
907905
num_sites = 40
908906
num_samples = 8
909-
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
907+
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.int8)
910908
with tsinfer.SampleData() as sample_data:
911909
for j in range(num_sites):
912910
sample_data.add_site(j, G[j])
@@ -978,7 +976,7 @@ def two_populations_high_migration_example(self, mutation_rate=10):
978976

979977
def get_random_data_example(self, num_sites, num_samples, seed=100):
980978
np.random.seed(seed)
981-
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.uint8)
979+
G = np.random.randint(2, size=(num_sites, num_samples)).astype(np.int8)
982980
with tsinfer.SampleData() as sample_data:
983981
for j in range(num_sites):
984982
sample_data.add_site(j, G[j])
@@ -1126,7 +1124,7 @@ def test_many_trees(self):
11261124

11271125
def get_random_data_example(self, position, num_samples, seed=100):
11281126
np.random.seed(seed)
1129-
G = np.random.randint(2, size=(position.shape[0], num_samples)).astype(np.uint8)
1127+
G = np.random.randint(2, size=(position.shape[0], num_samples)).astype(np.int8)
11301128
with tsinfer.SampleData() as sample_data:
11311129
for j, x in enumerate(position):
11321130
sample_data.add_site(x, G[j])

tests/test_formats.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import numcodecs
3333
import numcodecs.blosc as blosc
3434
import zarr
35+
import tskit
3536

3637
import tsinfer
3738
import tsinfer.formats as formats
@@ -89,7 +90,7 @@ def verify_data_round_trip(self, ts, input_file):
8990
self.assertEqual(input_file.num_samples, ts.num_samples)
9091
self.assertEqual(input_file.sequence_length, ts.sequence_length)
9192
self.assertEqual(input_file.num_sites, ts.num_sites)
92-
self.assertEqual(input_file.sites_genotypes.dtype, np.uint8)
93+
self.assertEqual(input_file.sites_genotypes.dtype, np.int8)
9394
self.assertEqual(input_file.sites_position.dtype, np.float64)
9495
# Take copies to avoid decompressing the data repeatedly.
9596
genotypes = input_file.sites_genotypes[:]
@@ -240,14 +241,14 @@ def test_provenance(self):
240241

241242
def test_variant_errors(self):
242243
input_file = formats.SampleData(sequence_length=10)
243-
genotypes = np.zeros(2, np.uint8)
244+
genotypes = np.zeros(2, np.int8)
244245
input_file.add_site(0, alleles=["0", "1"], genotypes=genotypes)
245246
for bad_position in [-1, 10, 100]:
246247
self.assertRaises(
247248
ValueError, input_file.add_site, position=bad_position,
248249
alleles=["0", "1"], genotypes=genotypes)
249250
for bad_genotypes in [[0, 2], [-1, 0], [], [0], [0, 0, 0]]:
250-
genotypes = np.array(bad_genotypes, dtype=np.uint8)
251+
genotypes = np.array(bad_genotypes, dtype=np.int8)
251252
self.assertRaises(
252253
ValueError, input_file.add_site, position=1,
253254
alleles=["0", "1"], genotypes=genotypes)
@@ -785,15 +786,15 @@ def get_example_data(self, sample_size, sequence_length, num_ancestors):
785786
num_sites = sample_data.num_inference_sites
786787
ancestors = []
787788
for j in range(num_ancestors):
788-
haplotype = np.zeros(num_sites, dtype=np.uint8) + tsinfer.UNKNOWN_ALLELE
789+
haplotype = np.full(num_sites, tskit.MISSING_DATA, dtype=np.int8)
789790
start = j
790791
end = max(num_sites - j, start + 1)
791792
self.assertLess(start, end)
792793
haplotype[start: end] = 0
793794
if start + j < end:
794795
haplotype[start + j: end] = 1
795-
self.assertTrue(np.all(haplotype[:start] == tsinfer.UNKNOWN_ALLELE))
796-
self.assertTrue(np.all(haplotype[end:] == tsinfer.UNKNOWN_ALLELE))
796+
self.assertTrue(np.all(haplotype[:start] == tskit.MISSING_DATA))
797+
self.assertTrue(np.all(haplotype[end:] == tskit.MISSING_DATA))
797798
focal_sites = np.array([start + k for k in range(j)], dtype=np.int32)
798799
focal_sites = focal_sites[focal_sites < end]
799800
haplotype[focal_sites] = 1
@@ -955,26 +956,26 @@ def test_add_ancestor_errors(self):
955956
self.assertRaises(
956957
ValueError, ancestor_data.add_ancestor,
957958
start=0, end=num_sites, age=1, focal_sites=[],
958-
haplotype=np.zeros(num_sites + 1, dtype=np.uint8))
959+
haplotype=np.zeros(num_sites + 1, dtype=np.int8))
959960
# Haplotypes must be < 2
960961
self.assertRaises(
961962
ValueError, ancestor_data.add_ancestor,
962963
start=0, end=num_sites, age=1, focal_sites=[],
963-
haplotype=np.zeros(num_sites, dtype=np.uint8) + 2)
964+
haplotype=np.full(num_sites, 2, dtype=np.int8))
964965
# focal sites must be within start:end
965966
self.assertRaises(
966967
ValueError, ancestor_data.add_ancestor,
967968
start=1, end=num_sites, age=1, focal_sites=[0],
968-
haplotype=np.ones(num_sites - 1, dtype=np.uint8))
969+
haplotype=np.ones(num_sites - 1, dtype=np.int8))
969970
self.assertRaises(
970971
ValueError, ancestor_data.add_ancestor,
971972
start=0, end=num_sites - 2, age=1, focal_sites=[num_sites - 1],
972-
haplotype=np.ones(num_sites, dtype=np.uint8))
973+
haplotype=np.ones(num_sites, dtype=np.int8))
973974
# focal sites must be set to 1
974975
self.assertRaises(
975976
ValueError, ancestor_data.add_ancestor,
976977
start=0, end=num_sites, age=1, focal_sites=[0],
977-
haplotype=np.zeros(num_sites, dtype=np.uint8))
978+
haplotype=np.zeros(num_sites, dtype=np.int8))
978979

979980
@unittest.skipIf(sys.platform == "win32",
980981
"windows simultaneous file permissions issue")

tests/test_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def verify_ancestors(self, sample_data, ancestor_data):
693693
self.assertEqual(a.shape[0], end[j] - start[j])
694694
h = np.zeros(ancestor_data.num_sites, dtype=np.uint8)
695695
h[start[j]: end[j]] = a
696-
self.assertTrue(np.all(h[start[j]:end[j]] != tsinfer.UNKNOWN_ALLELE))
696+
self.assertTrue(np.all(h[start[j]:end[j]] != tskit.MISSING_DATA))
697697
self.assertTrue(np.all(h[focal_sites[j]] == 1))
698698
used_sites.extend(focal_sites[j])
699699
self.assertGreater(age[j], 0)

tsinfer/algorithm.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
import collections
2929

3030
import numpy as np
31-
import tskit
3231
import sortedcontainers
32+
import tskit
3333

3434
import tsinfer.constants as constants
3535

@@ -185,7 +185,7 @@ def make_ancestor(self, focal_sites, a):
185185
# check all focal sites in this ancestor are at the same age
186186
assert all([self.sites[fs].age == focal_age for fs in focal_sites])
187187

188-
a[:] = constants.UNKNOWN_ALLELE
188+
a[:] = tskit.MISSING_DATA
189189
for focal_site in focal_sites:
190190
a[focal_site] = 1
191191
S = set(np.where(self.sites[focal_sites[0]].genotypes == 1)[0])
@@ -203,29 +203,29 @@ def make_ancestor(self, focal_sites, a):
203203
focal_site = focal_sites[-1]
204204
last_site = self.compute_ancestral_states(
205205
a, focal_site, range(focal_site + 1, self.num_sites))
206-
assert a[last_site] != constants.UNKNOWN_ALLELE
206+
assert a[last_site] != tskit.MISSING_DATA
207207
end = last_site + 1
208208
# Go leftwards
209209
focal_site = focal_sites[0]
210210
last_site = self.compute_ancestral_states(
211211
a, focal_site, range(focal_site - 1, -1, -1))
212-
assert a[last_site] != constants.UNKNOWN_ALLELE
212+
assert a[last_site] != tskit.MISSING_DATA
213213
start = last_site
214214
return start, end
215215

216216
# Version with 1 focal site
217217
# assert len(focal_sites) == 1
218218
# focal_site = focal_sites[0]
219-
# a[:] = constants.UNKNOWN_ALLELE
219+
# a[:] = tskit.MISSING_DATA
220220
# a[focal_site] = 1
221221

222222
# last_site = self.compute_ancestral_states(
223223
# a, focal_site, range(focal_site + 1, self.num_sites))
224-
# assert a[last_site] != constants.UNKNOWN_ALLELE
224+
# assert a[last_site] != tskit.MISSING_DATA
225225
# end = last_site + 1
226226
# last_site = self.compute_ancestral_states(
227227
# a, focal_site, range(focal_site - 1, -1, -1))
228-
# assert a[last_site] != constants.UNKNOWN_ALLELE
228+
# assert a[last_site] != tskit.MISSING_DATA
229229
# start = last_site
230230
# return start, end
231231

@@ -582,7 +582,7 @@ def dump_mutations(self):
582582
site[j] = l
583583
node[j] = u
584584
derived_state[j] = d
585-
parent[j] = -1
585+
parent[j] = tskit.NULL
586586
if d == 0:
587587
parent[j] = p
588588
j += 1
@@ -595,7 +595,7 @@ def is_descendant(pi, u, v):
595595
v is on the path to root from u.
596596
"""
597597
ret = False
598-
if v != -1:
598+
if v != tskit.NULL:
599599
w = u
600600
path = []
601601
while w != v and w != tskit.NULL:
@@ -639,6 +639,9 @@ def print_state(self):
639639
for l in range(self.num_sites):
640640
print(l, self.max_likelihood_node[l], self.traceback[l], sep="\t")
641641

642+
def is_root(self, u):
643+
return self.parent[u] == tskit.NULL
644+
642645
def check_likelihoods(self):
643646
assert len(set(self.likelihood_nodes)) == len(self.likelihood_nodes)
644647
# Every value in L_nodes must be positive.
@@ -649,7 +652,7 @@ def check_likelihoods(self):
649652
if v >= 0:
650653
assert u in self.likelihood_nodes
651654
# Roots other than 0 should have v == -2
652-
if u != 0 and self.parent[u] == -1 and self.left_child[u] == -1:
655+
if u != 0 and self.is_root(u) and self.left_child[u] == -1:
653656
# print("root: u = ", u, self.parent[u], self.left_child[u])
654657
assert v == -2
655658

@@ -727,8 +730,8 @@ def compress_likelihoods(self):
727730
for u in old_likelihood_nodes:
728731
# We need to find the likelihood of the parent of u. If this is
729732
# the same as u, we can delete it.
730-
p = self.parent[u]
731-
if p != -1:
733+
if not self.is_root(u):
734+
p = self.parent[u]
732735
cached_paths.append(p)
733736
v = p
734737
while self.likelihood[v] == -1 and L_cache[v] == -1:
@@ -788,7 +791,7 @@ def insert_edge(self, edge):
788791
self.right_child[p] = c
789792

790793
def is_nonzero_root(self, u):
791-
return u != 0 and self.parent[u] == -1 and self.left_child[u] == -1
794+
return u != 0 and self.is_root(u) and self.left_child[u] == -1
792795

793796
def find_path(self, h, start, end, match):
794797
Il = self.tree_sequence_builder.left_index
@@ -833,7 +836,7 @@ def find_path(self, h, start, end, match):
833836
assert left < right
834837

835838
for u in range(n):
836-
if self.parent[u] != -1:
839+
if not self.is_root(u):
837840
self.likelihood[u] = -1
838841

839842
last_root = 0
@@ -940,8 +943,8 @@ def run_traceback(self, start, end, match):
940943
k = M - 1
941944
# Construct the matched haplotype
942945
match[:] = 0
943-
match[:start] = constants.UNKNOWN_ALLELE
944-
match[end:] = constants.UNKNOWN_ALLELE
946+
match[:start] = tskit.MISSING_DATA
947+
match[end:] = tskit.MISSING_DATA
945948
# Reset the tree.
946949
self.parent[:] = -1
947950
self.left_child[:] = -1

tsinfer/constants.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
# along with tsinfer. If not, see <http://www.gnu.org/licenses/>.
1818
#
1919
"""
20-
Collection of constants used in tsinfer.
20+
Collection of constants used in tsinfer. We also make use of constants defined in tskit.
2121
"""
2222

23-
UNKNOWN_ALLELE = 255
24-
2523
C_ENGINE = "C"
2624
PY_ENGINE = "P"
2725

0 commit comments

Comments
 (0)