-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconftest.py
96 lines (63 loc) · 2.45 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from __future__ import annotations
from itertools import product
import numpy as np
import pytest
import torch
from scipy.constants import c
def pytest_addoption(parser):
    """Register the ``--all`` command-line flag with pytest.

    Without the flag, every parametrized fixture in this file runs only the
    first entry of its parameter list; with it, all entries are exercised.
    """
    parser.addoption(
        "--all",
        help="run all combinations",
        action="store_true",
    )
@pytest.fixture
def allcomb(request):
    """Report whether pytest was invoked with the ``--all`` flag.

    The parametrized fixtures below pass this value to ``get_params`` to
    decide between running every parameter or only the default one.
    """
    run_everything = request.config.getoption("--all")
    return run_everything
# Parameter grids consumed by the fixtures below.  Without ``--all`` only the
# first entry of each list is exercised (see ``get_params``); with ``--all``
# every entry is.
# NOTE(review): ``c`` (numerically the speed of light) is presumably chosen
# so that the wavelength comes out as exactly 1 — confirm.
carrier_frequency_list = [2.4e9, c]
Nfft_list = [64, 31]  # one power-of-two size and one odd size
num_carriers_list = [1, 16]
num_symbols_list = [1, 16]
num_ant_tx_list = [1, 8]
num_ant_rx_list = [1, 8]
scs_list = [1.0, 15e3]
cp_frac_list = np.linspace(1e-2, 1, 5)  # 5 evenly spaced values in [0.01, 1]
num_guard_carriers_list = list(product([0, 1, 8], [0, 1, 8]))  # all 9 (a, b) pairs
num_paths_list = [3, 1, 0]
# NOTE(review): no fixture in this file consumes dc_null_list; tests
# presumably import or parametrize it directly — confirm.
dc_null_list = [True, False]
# CPU first (it is the default device), then one CUDA device per visible GPU.
# The device type is spelled out explicitly instead of relying on the legacy
# ordinal-only ``torch.device(i)`` form, which leaves the type implicit.
device_list = [torch.device("cpu")] + [
    torch.device("cuda", index) for index in range(torch.cuda.device_count())
]
def get_params(request, allcomb, default):
    """Pick the value a parametrized fixture should hand to the test.

    With ``--all`` every parametrization runs unchanged.  Otherwise only the
    parametrization equal to ``default`` runs; all others are skipped, so a
    plain test run covers a single representative combination.

    Args:
        request: The pytest fixture request; ``request.param`` is the
            current parameter value.
        allcomb: Truthy when ``--all`` was passed on the command line.
        default: The single value to run when ``allcomb`` is falsy.

    Returns:
        ``request.param`` if ``allcomb`` is truthy, otherwise ``default``.
    """
    if not allcomb:
        # Quick mode: skip every parametrization except the default one.
        if request.param != default:
            pytest.skip("All combinations will be run with --all option")
        return default
    return request.param
@pytest.fixture(params=carrier_frequency_list)
def carrier_frequency(request, allcomb):
    """Provide an entry of ``carrier_frequency_list`` (first entry only without ``--all``)."""
    first = carrier_frequency_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=Nfft_list)
def Nfft(request, allcomb):
    """Provide an entry of ``Nfft_list`` (first entry only without ``--all``)."""
    first = Nfft_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_carriers_list)
def num_carriers(request, allcomb):
    """Provide an entry of ``num_carriers_list`` (first entry only without ``--all``)."""
    first = num_carriers_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_symbols_list)
def num_symbols(request, allcomb):
    """Provide an entry of ``num_symbols_list`` (first entry only without ``--all``)."""
    first = num_symbols_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_ant_tx_list)
def num_ant_tx(request, allcomb):
    """Provide an entry of ``num_ant_tx_list`` (first entry only without ``--all``)."""
    first = num_ant_tx_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_ant_rx_list)
def num_ant_rx(request, allcomb):
    """Provide an entry of ``num_ant_rx_list`` (first entry only without ``--all``)."""
    first = num_ant_rx_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=scs_list)
def scs(request, allcomb):
    """Provide an entry of ``scs_list`` (first entry only without ``--all``)."""
    first = scs_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=cp_frac_list)
def cp_frac(request, allcomb):
    """Provide an entry of ``cp_frac_list`` (first entry only without ``--all``)."""
    first = cp_frac_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_guard_carriers_list)
def num_guard_carriers(request, allcomb):
    """Provide an ``(a, b)`` pair from ``num_guard_carriers_list`` (first pair only without ``--all``)."""
    first = num_guard_carriers_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=num_paths_list)
def num_paths(request, allcomb):
    """Provide an entry of ``num_paths_list`` (first entry only without ``--all``)."""
    first = num_paths_list[0]
    return get_params(request, allcomb, default=first)
@pytest.fixture(params=device_list)
def device(request, allcomb):
    """Provide a ``torch.device`` from ``device_list``; without ``--all`` only the first (CPU) runs."""
    first = device_list[0]
    return get_params(request, allcomb, default=first)