@@ -2239,33 +2239,57 @@ def test_basic_example(self):
2239
2239
ts = self .get_example_tree_sequence ()
2240
2240
n = ts .get_num_samples ()
2241
2241
result = ts .allele_frequency_spectrum (
2242
- [n ], ts .get_samples (), [0 , ts .get_sequence_length ()]
2242
+ [n ],
2243
+ ts .get_samples (),
2244
+ [0 , ts .get_sequence_length ()],
2245
+ mode = "branch" ,
2246
+ time_windows = [0 , np .inf ],
2243
2247
)
2244
- assert result .shape == (1 , n + 1 )
2248
+ assert result .shape == (1 , 1 , n + 1 )
2245
2249
result = ts .allele_frequency_spectrum (
2246
- [n ], ts .get_samples (), [0 , ts .get_sequence_length ()], polarised = True
2250
+ [n ],
2251
+ ts .get_samples (),
2252
+ [0 , ts .get_sequence_length ()],
2253
+ mode = "branch" ,
2254
+ time_windows = [0 , np .inf ],
2255
+ polarised = True ,
2247
2256
)
2248
- assert result .shape == (1 , n + 1 )
2257
+ assert result .shape == (1 , 1 , n + 1 )
2249
2258
2250
2259
def test_output_dims (self ):
2251
2260
ts = self .get_example_tree_sequence ()
2252
2261
samples = ts .get_samples ()
2253
2262
L = ts .get_sequence_length ()
2254
2263
n = len (samples )
2264
+ time_windows = [0 , np .inf ]
2255
2265
2256
2266
for mode in ["site" , "branch" ]:
2257
2267
for s in [[n ], [n - 2 , 2 ], [n - 4 , 2 , 2 ], [1 ] * n ]:
2258
2268
s = np .array (s , dtype = np .uint32 )
2259
2269
windows = [0 , L ]
2260
2270
for windows in [[0 , L ], [0 , L / 2 , L ], np .linspace (0 , L , num = 10 )]:
2261
2271
jafs = ts .allele_frequency_spectrum (
2262
- s , samples , windows , mode = mode , polarised = True
2272
+ s ,
2273
+ samples ,
2274
+ windows ,
2275
+ mode = mode ,
2276
+ time_windows = time_windows ,
2277
+ polarised = True ,
2278
+ )
2279
+ assert jafs .shape == tuple (
2280
+ [len (windows ) - 1 ] + [len (time_windows ) - 1 ] + list (s + 1 )
2263
2281
)
2264
- assert jafs .shape == tuple ([len (windows ) - 1 ] + list (s + 1 ))
2265
2282
jafs = ts .allele_frequency_spectrum (
2266
- s , samples , windows , mode = mode , polarised = False
2283
+ s ,
2284
+ samples ,
2285
+ windows ,
2286
+ mode = mode ,
2287
+ time_windows = time_windows ,
2288
+ polarised = False ,
2289
+ )
2290
+ assert jafs .shape == tuple (
2291
+ [len (windows ) - 1 ] + [len (time_windows ) - 1 ] + list (s + 1 )
2267
2292
)
2268
- assert jafs .shape == tuple ([len (windows ) - 1 ] + list (s + 1 ))
2269
2293
2270
2294
def test_node_mode_not_supported (self ):
2271
2295
ts = self .get_example_tree_sequence ()
@@ -2275,8 +2299,142 @@ def test_node_mode_not_supported(self):
2275
2299
ts .get_samples (),
2276
2300
[0 , ts .get_sequence_length ()],
2277
2301
mode = "node" ,
2302
+ time_windows = [0 , np .inf ],
2278
2303
)
2279
2304
2305
+ def test_polarised (self ):
2306
+ """
2307
+ Temporary duplicate from class OneWaySampleStatsMixin
2308
+ used to provide the time_windows argument.
2309
+ """
2310
+ # TODO move this to the top level.
2311
+ ts , method = self .get_method ()
2312
+ samples = ts .get_samples ()
2313
+ n = len (samples )
2314
+ windows = [0 , ts .get_sequence_length ()]
2315
+ method (
2316
+ [n ],
2317
+ samples ,
2318
+ windows ,
2319
+ time_windows = [0 , np .inf ],
2320
+ mode = "branch" ,
2321
+ polarised = True ,
2322
+ )
2323
+ method (
2324
+ [n ],
2325
+ samples ,
2326
+ windows ,
2327
+ time_windows = [0 , np .inf ],
2328
+ mode = "branch" ,
2329
+ polarised = False ,
2330
+ )
2331
+
2332
+ def test_polarisation (self ):
2333
+ ts , f , params = self .get_example ()
2334
+ with pytest .raises (TypeError ):
2335
+ f (polarised = "sdf" , time_windows = [0 , np .inf ], mode = "branch" , ** params )
2336
+ x1 = f (polarised = False , time_windows = [0 , np .inf ], mode = "branch" , ** params )
2337
+ x2 = f (polarised = True , time_windows = [0 , np .inf ], mode = "branch" , ** params )
2338
+ # Basic check just to run both code paths
2339
+ assert x1 .shape == x2 .shape
2340
+
2341
+ def test_mode_errors (self ):
2342
+ _ , f , params = self .get_example ()
2343
+ for bad_mode in ["" , "not a mode" , "SITE" , "x" * 8192 ]:
2344
+ with pytest .raises (ValueError ):
2345
+ f (mode = bad_mode , time_windows = [0 , np .inf ], ** params )
2346
+
2347
+ for bad_type in [123 , {}, None , [[]]]:
2348
+ with pytest .raises (TypeError ):
2349
+ f (mode = bad_type , time_windows = [0 , np .inf ], ** params )
2350
+
2351
+ def test_window_errors (self ):
2352
+ ts , f , params = self .get_example ()
2353
+ del params ["windows" ]
2354
+ for bad_array in ["asdf" , None , [[[[]], [[]]]], np .zeros ((10 , 3 , 4 ))]:
2355
+ with pytest .raises (ValueError ):
2356
+ f (windows = bad_array , time_windows = [0 , np .inf ], mode = "branch" , ** params )
2357
+
2358
+ for bad_windows in [[], [0 ]]:
2359
+ with pytest .raises (ValueError ):
2360
+ f (
2361
+ windows = bad_windows ,
2362
+ time_windows = [0 , np .inf ],
2363
+ mode = "branch" ,
2364
+ ** params ,
2365
+ )
2366
+ L = ts .get_sequence_length ()
2367
+ bad_windows = [
2368
+ [L , 0 ],
2369
+ [0.1 , L ],
2370
+ [- 1 , L ],
2371
+ [0 , L + 0.1 ],
2372
+ [0 , 0.1 , 0.1 , L ],
2373
+ [0 , - 1 , L ],
2374
+ [0 , 0.1 , 0.05 , 0.2 , L ],
2375
+ ]
2376
+ for bad_window in bad_windows :
2377
+ with pytest .raises (_tskit .LibraryError ):
2378
+ f (windows = bad_window , time_windows = [0 , np .inf ], mode = "branch" , ** params )
2379
+
2380
+ def test_windows_output (self ):
2381
+ ts , f , params = self .get_example ()
2382
+ del params ["windows" ]
2383
+ for num_windows in range (1 , 10 ):
2384
+ windows = np .linspace (0 , ts .get_sequence_length (), num = num_windows + 1 )
2385
+ assert windows .shape [0 ] == num_windows + 1
2386
+ sigma = f (
2387
+ windows = windows , time_windows = [0 , np .inf ], mode = "branch" , ** params
2388
+ )
2389
+ assert sigma .shape [0 ] == num_windows
2390
+
2391
+ def test_bad_sample_sets (self ):
2392
+ ts , f , params = self .get_example ()
2393
+ del params ["sample_set_sizes" ]
2394
+ del params ["sample_sets" ]
2395
+
2396
+ with pytest .raises (_tskit .LibraryError ):
2397
+ f (
2398
+ sample_sets = [],
2399
+ sample_set_sizes = [],
2400
+ time_windows = [0 , np .inf ],
2401
+ mode = "branch" ,
2402
+ ** params ,
2403
+ )
2404
+
2405
+ n = ts .get_num_samples ()
2406
+ samples = ts .get_samples ()
2407
+ for bad_set_sizes in [[], [1 ], [n - 1 ], [n + 1 ], [n - 3 , 1 , 1 ], [1 , n - 2 ]]:
2408
+ with pytest .raises (ValueError ):
2409
+ f (
2410
+ sample_set_sizes = bad_set_sizes ,
2411
+ sample_sets = samples ,
2412
+ time_windows = [0 , np .inf ],
2413
+ mode = "branch" ,
2414
+ ** params ,
2415
+ )
2416
+
2417
+ N = ts .get_num_nodes ()
2418
+ for bad_node in [- 1 , N , N + 1 , - N ]:
2419
+ with pytest .raises (_tskit .LibraryError ):
2420
+ f (
2421
+ sample_set_sizes = [2 ],
2422
+ sample_sets = [0 , bad_node ],
2423
+ time_windows = [0 , np .inf ],
2424
+ mode = "branch" ,
2425
+ ** params ,
2426
+ )
2427
+
2428
+ for bad_sample in [n , n + 1 , N - 1 ]:
2429
+ with pytest .raises (_tskit .LibraryError ):
2430
+ f (
2431
+ sample_set_sizes = [2 ],
2432
+ sample_sets = [0 , bad_sample ],
2433
+ time_windows = [0 , np .inf ],
2434
+ mode = "branch" ,
2435
+ ** params ,
2436
+ )
2437
+
2280
2438
2281
2439
class TwoWaySampleStatsMixin (SampleSetMixin ):
2282
2440
"""
0 commit comments