Skip to content

Commit 6b3ab4f

Browse files
committed
Adapting some tests with new time_windows parameter
1 parent da5f205 commit 6b3ab4f

File tree

1 file changed

+166
-8
lines changed

1 file changed

+166
-8
lines changed

python/tests/test_lowlevel.py

Lines changed: 166 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,33 +2239,57 @@ def test_basic_example(self):
22392239
ts = self.get_example_tree_sequence()
22402240
n = ts.get_num_samples()
22412241
result = ts.allele_frequency_spectrum(
2242-
[n], ts.get_samples(), [0, ts.get_sequence_length()]
2242+
[n],
2243+
ts.get_samples(),
2244+
[0, ts.get_sequence_length()],
2245+
mode="branch",
2246+
time_windows=[0, np.inf],
22432247
)
2244-
assert result.shape == (1, n + 1)
2248+
assert result.shape == (1, 1, n + 1)
22452249
result = ts.allele_frequency_spectrum(
2246-
[n], ts.get_samples(), [0, ts.get_sequence_length()], polarised=True
2250+
[n],
2251+
ts.get_samples(),
2252+
[0, ts.get_sequence_length()],
2253+
mode="branch",
2254+
time_windows=[0, np.inf],
2255+
polarised=True,
22472256
)
2248-
assert result.shape == (1, n + 1)
2257+
assert result.shape == (1, 1, n + 1)
22492258

22502259
def test_output_dims(self):
22512260
ts = self.get_example_tree_sequence()
22522261
samples = ts.get_samples()
22532262
L = ts.get_sequence_length()
22542263
n = len(samples)
2264+
time_windows = [0, np.inf]
22552265

22562266
for mode in ["site", "branch"]:
22572267
for s in [[n], [n - 2, 2], [n - 4, 2, 2], [1] * n]:
22582268
s = np.array(s, dtype=np.uint32)
22592269
windows = [0, L]
22602270
for windows in [[0, L], [0, L / 2, L], np.linspace(0, L, num=10)]:
22612271
jafs = ts.allele_frequency_spectrum(
2262-
s, samples, windows, mode=mode, polarised=True
2272+
s,
2273+
samples,
2274+
windows,
2275+
mode=mode,
2276+
time_windows=time_windows,
2277+
polarised=True,
2278+
)
2279+
assert jafs.shape == tuple(
2280+
[len(windows) - 1] + [len(time_windows) - 1] + list(s + 1)
22632281
)
2264-
assert jafs.shape == tuple([len(windows) - 1] + list(s + 1))
22652282
jafs = ts.allele_frequency_spectrum(
2266-
s, samples, windows, mode=mode, polarised=False
2283+
s,
2284+
samples,
2285+
windows,
2286+
mode=mode,
2287+
time_windows=time_windows,
2288+
polarised=False,
2289+
)
2290+
assert jafs.shape == tuple(
2291+
[len(windows) - 1] + [len(time_windows) - 1] + list(s + 1)
22672292
)
2268-
assert jafs.shape == tuple([len(windows) - 1] + list(s + 1))
22692293

22702294
def test_node_mode_not_supported(self):
22712295
ts = self.get_example_tree_sequence()
@@ -2275,8 +2299,142 @@ def test_node_mode_not_supported(self):
22752299
ts.get_samples(),
22762300
[0, ts.get_sequence_length()],
22772301
mode="node",
2302+
time_windows=[0, np.inf],
22782303
)
22792304

2305+
def test_polarised(self):
2306+
"""
2307+
Temporary duplicate from class OneWaySampleStatsMixin
2308+
used to provide the time_windows argument.
2309+
"""
2310+
# TODO move this to the top level.
2311+
ts, method = self.get_method()
2312+
samples = ts.get_samples()
2313+
n = len(samples)
2314+
windows = [0, ts.get_sequence_length()]
2315+
method(
2316+
[n],
2317+
samples,
2318+
windows,
2319+
time_windows=[0, np.inf],
2320+
mode="branch",
2321+
polarised=True,
2322+
)
2323+
method(
2324+
[n],
2325+
samples,
2326+
windows,
2327+
time_windows=[0, np.inf],
2328+
mode="branch",
2329+
polarised=False,
2330+
)
2331+
2332+
def test_polarisation(self):
2333+
ts, f, params = self.get_example()
2334+
with pytest.raises(TypeError):
2335+
f(polarised="sdf", time_windows=[0, np.inf], mode="branch", **params)
2336+
x1 = f(polarised=False, time_windows=[0, np.inf], mode="branch", **params)
2337+
x2 = f(polarised=True, time_windows=[0, np.inf], mode="branch", **params)
2338+
# Basic check just to run both code paths
2339+
assert x1.shape == x2.shape
2340+
2341+
def test_mode_errors(self):
2342+
_, f, params = self.get_example()
2343+
for bad_mode in ["", "not a mode", "SITE", "x" * 8192]:
2344+
with pytest.raises(ValueError):
2345+
f(mode=bad_mode, time_windows=[0, np.inf], **params)
2346+
2347+
for bad_type in [123, {}, None, [[]]]:
2348+
with pytest.raises(TypeError):
2349+
f(mode=bad_type, time_windows=[0, np.inf], **params)
2350+
2351+
def test_window_errors(self):
2352+
ts, f, params = self.get_example()
2353+
del params["windows"]
2354+
for bad_array in ["asdf", None, [[[[]], [[]]]], np.zeros((10, 3, 4))]:
2355+
with pytest.raises(ValueError):
2356+
f(windows=bad_array, time_windows=[0, np.inf], mode="branch", **params)
2357+
2358+
for bad_windows in [[], [0]]:
2359+
with pytest.raises(ValueError):
2360+
f(
2361+
windows=bad_windows,
2362+
time_windows=[0, np.inf],
2363+
mode="branch",
2364+
**params,
2365+
)
2366+
L = ts.get_sequence_length()
2367+
bad_windows = [
2368+
[L, 0],
2369+
[0.1, L],
2370+
[-1, L],
2371+
[0, L + 0.1],
2372+
[0, 0.1, 0.1, L],
2373+
[0, -1, L],
2374+
[0, 0.1, 0.05, 0.2, L],
2375+
]
2376+
for bad_window in bad_windows:
2377+
with pytest.raises(_tskit.LibraryError):
2378+
f(windows=bad_window, time_windows=[0, np.inf], mode="branch", **params)
2379+
2380+
def test_windows_output(self):
2381+
ts, f, params = self.get_example()
2382+
del params["windows"]
2383+
for num_windows in range(1, 10):
2384+
windows = np.linspace(0, ts.get_sequence_length(), num=num_windows + 1)
2385+
assert windows.shape[0] == num_windows + 1
2386+
sigma = f(
2387+
windows=windows, time_windows=[0, np.inf], mode="branch", **params
2388+
)
2389+
assert sigma.shape[0] == num_windows
2390+
2391+
def test_bad_sample_sets(self):
2392+
ts, f, params = self.get_example()
2393+
del params["sample_set_sizes"]
2394+
del params["sample_sets"]
2395+
2396+
with pytest.raises(_tskit.LibraryError):
2397+
f(
2398+
sample_sets=[],
2399+
sample_set_sizes=[],
2400+
time_windows=[0, np.inf],
2401+
mode="branch",
2402+
**params,
2403+
)
2404+
2405+
n = ts.get_num_samples()
2406+
samples = ts.get_samples()
2407+
for bad_set_sizes in [[], [1], [n - 1], [n + 1], [n - 3, 1, 1], [1, n - 2]]:
2408+
with pytest.raises(ValueError):
2409+
f(
2410+
sample_set_sizes=bad_set_sizes,
2411+
sample_sets=samples,
2412+
time_windows=[0, np.inf],
2413+
mode="branch",
2414+
**params,
2415+
)
2416+
2417+
N = ts.get_num_nodes()
2418+
for bad_node in [-1, N, N + 1, -N]:
2419+
with pytest.raises(_tskit.LibraryError):
2420+
f(
2421+
sample_set_sizes=[2],
2422+
sample_sets=[0, bad_node],
2423+
time_windows=[0, np.inf],
2424+
mode="branch",
2425+
**params,
2426+
)
2427+
2428+
for bad_sample in [n, n + 1, N - 1]:
2429+
with pytest.raises(_tskit.LibraryError):
2430+
f(
2431+
sample_set_sizes=[2],
2432+
sample_sets=[0, bad_sample],
2433+
time_windows=[0, np.inf],
2434+
mode="branch",
2435+
**params,
2436+
)
2437+
22802438

22812439
class TwoWaySampleStatsMixin(SampleSetMixin):
22822440
"""

0 commit comments

Comments
 (0)