-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathconftest.py
198 lines (170 loc) · 5.66 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import shutil
import pytest
import psutil
import smartsim
from smartsim.database import (
CobaltOrchestrator, SlurmOrchestrator,
PBSOrchestrator, Orchestrator,
LSFOrchestrator
)
from smartsim.settings import (
SrunSettings, AprunSettings,
JsrunSettings, RunSettings
)
from smartsim.config import CONFIG
# Module-level test configuration. Globals are acceptable here because this
# is a test-only file; the hooks and helper classes below read these values.
# (``test_account`` is not set here; ``get_account`` assigns it.)

# Filesystem locations derived from where this conftest lives
test_path = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(test_path, "tests", "test_output")

# Values taken from the SmartSim configuration object
test_launcher = CONFIG.test_launcher
test_device = CONFIG.test_device
test_nic = CONFIG.test_interface
def get_account():
    """Read the test account from the SmartSim config and store it globally.

    Returns:
        The configured account. The same value is also assigned to the
        module-level ``test_account`` so other helpers can read it.
    """
    global test_account
    test_account = account = CONFIG.test_account
    return account
def print_test_configuration():
    """Print the configuration this test session runs with.

    Module-level settings are only read here, never assigned, so no
    ``global`` declarations are required.
    """
    print("TEST_SMARTSIM_LOCATION:", smartsim.__path__)
    print("TEST_PATH:", test_path)
    print("TEST_LAUNCHER:", test_launcher)
    # the account line is skipped when the account is the empty string
    if test_account != "":
        print("TEST_ACCOUNT:", test_account)
    trailing_lines = (
        ("TEST_DEVICE:", test_device),
        ("TEST_NETWORK_INTERFACE (WLM only):", test_nic),
        ("TEST_DIR:", test_dir),
    )
    for label, value in trailing_lines:
        print(label, value)
    print("Test output will be located in TEST_DIR if there is a failure")
def pytest_configure():
    """pytest hook: publish the test configuration on the ``pytest`` namespace.

    Sets ``pytest.test_launcher``, ``pytest.wlm_options`` and
    ``pytest.test_account``. The account comes from ``get_account``, which
    also stores it in the module-level ``test_account``.
    """
    pytest.test_launcher = test_launcher
    pytest.wlm_options = ["slurm", "pbs", "cobalt", "lsf"]
    pytest.test_account = get_account()
def pytest_sessionstart(session):
    """
    Called after the Session object has been created and
    before performing collection and entering the run test loop.

    Recreates ``test_dir`` as an empty directory, removing any output left by
    a previous run, then prints the test configuration.
    """
    if os.path.isdir(test_dir):
        shutil.rmtree(test_dir)
    # makedirs (not mkdir) so a missing parent "tests" directory does not
    # abort the whole session with FileNotFoundError
    os.makedirs(test_dir)
    print_test_configuration()
def pytest_sessionfinish(session, exitstatus):
    """pytest hook run after the whole session, before the exit status is returned.

    A zero exit status removes ``test_dir``. Any other status keeps the
    output directory and kills every process spawned during the run.
    """
    if exitstatus != 0:
        # failure: keep the output, but clean up any spawned processes
        kill_all_test_spawned_processes()
        return
    shutil.rmtree(test_dir)
def kill_all_test_spawned_processes():
    """Kill every descendant process of the current (pytest) process.

    Called when the test session fails so that processes launched by tests do
    not outlive the run. This is best-effort: failures are reported with a
    printed message, never raised.
    """
    pid = os.getpid()
    try:
        parent = psutil.Process(pid)
    except psutil.Error:
        # could not find parent process id
        return
    try:
        children = parent.children(recursive=True)
    except Exception:
        print("Not all processes were killed after test")
        return
    all_killed = True
    # Handle each child on its own, so one failure does not leave the
    # remaining children running.
    for child in children:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # already gone, e.g. it exited when its own parent was killed
            continue
        except Exception:
            all_killed = False
    if not all_killed:
        print("Not all processes were killed after test")
@pytest.fixture
def wlmutils():
    """Fixture: hand tests the ``WLMUtils`` class (not an instance)."""
    helper_cls = WLMUtils
    return helper_cls
class WLMUtils:
    """Helpers that build launcher-specific objects for the configured WLM.

    All methods read the module-level test configuration (``test_launcher``,
    ``test_account``, ``test_nic``, ``test_device``). Because those names are
    only read, never assigned, no ``global`` declarations are needed.
    """

    @staticmethod
    def get_test_launcher():
        """Return the launcher name configured for this test run."""
        return test_launcher

    @staticmethod
    def get_test_account():
        """Return the module-level ``test_account`` (set by ``get_account``)."""
        return test_account

    @staticmethod
    def get_test_interface():
        """Return the network interface configured for this test run."""
        return test_nic

    @staticmethod
    def get_run_settings(exe, args, nodes=1, ntasks=1, **kwargs):
        """Build run settings for ``exe``/``args`` that match the launcher.

        Args:
            exe: executable to run.
            args: arguments passed to the executable.
            nodes: node count (used by the slurm and lsf launchers).
            ntasks: task count.
            **kwargs: extra run arguments merged over the defaults. Used by
                the WLM launchers only; ignored for any other launcher.

        Returns:
            ``SrunSettings`` for slurm, ``AprunSettings`` for pbs and cobalt,
            ``JsrunSettings`` for lsf, and plain ``RunSettings`` otherwise.
        """
        launcher = test_launcher
        if launcher == "slurm":
            slurm_args = {"nodes": nodes, "ntasks": ntasks, "time": "00:10:00", **kwargs}
            return SrunSettings(exe, args, run_args=slurm_args)
        if launcher in ("pbs", "cobalt"):
            # TODO allow user to pick aprun vs MPIrun
            aprun_args = {"pes": ntasks, **kwargs}
            return AprunSettings(exe, args, run_args=aprun_args)
        if launcher == "lsf":
            jsrun_args = {"nrs": nodes, "tasks_per_rs": max(ntasks // nodes, 1), **kwargs}
            return JsrunSettings(exe, args, run_args=jsrun_args)
        return RunSettings(exe, args)

    @staticmethod
    def get_orchestrator(nodes=1, port=6780, batch=False):
        """Build an orchestrator for the configured launcher.

        For lsf, one GPU per shard is requested when ``test_device`` is
        ``"GPU"``, and the project is taken from ``get_account()``. Any
        launcher other than slurm/pbs/cobalt/lsf gets a plain ``Orchestrator``
        on the loopback interface ``"lo"``; ``nodes`` and ``batch`` are unused
        in that case.
        """
        launcher = test_launcher
        if launcher == "lsf":
            return LSFOrchestrator(
                db_nodes=nodes,
                port=port,
                batch=batch,
                gpus_per_shard=1 if test_device == "GPU" else 0,
                project=get_account(),
                interface=test_nic,
            )
        wlm_orchestrators = {
            "slurm": SlurmOrchestrator,
            "pbs": PBSOrchestrator,
            "cobalt": CobaltOrchestrator,
        }
        orchestrator_cls = wlm_orchestrators.get(launcher)
        if orchestrator_cls is None:
            return Orchestrator(port=port, interface="lo")
        return orchestrator_cls(db_nodes=nodes, port=port, batch=batch, interface=test_nic)
@pytest.fixture
def fileutils():
    """Fixture: hand tests the ``FileUtils`` class (not an instance)."""
    helper_cls = FileUtils
    return helper_cls
class FileUtils:
    """Path helpers for test output and test configuration files.

    Output paths live under the module-level ``test_dir``. Configuration files
    and directories are looked up under ``<test_path>/tests/test_configs``.
    """

    @staticmethod
    def get_test_dir(dir_name):
        """Return the path of ``dir_name`` inside ``test_dir`` (not created)."""
        return os.path.join(test_dir, dir_name)

    @staticmethod
    def make_test_dir(dir_name):
        """Create ``dir_name`` inside ``test_dir`` if needed and return its path.

        An already-existing directory is reused, and missing intermediate
        directories are created as well. Other failures propagate, for example
        a permission error or the path existing as a regular file. Previously
        every exception was swallowed, and the returned path might never have
        been created.
        """
        dir_path = os.path.join(test_dir, dir_name)
        os.makedirs(dir_path, exist_ok=True)
        return dir_path

    @staticmethod
    def get_test_conf_path(filename):
        """Return the path of ``filename`` in ``tests/test_configs``."""
        return os.path.join(test_path, "tests", "test_configs", filename)

    @staticmethod
    def get_test_dir_path(dirname):
        """Return the path of directory ``dirname`` in ``tests/test_configs``."""
        return os.path.join(test_path, "tests", "test_configs", dirname)
@pytest.fixture
def mlutils():
    """Fixture: hand tests the ``MLUtils`` class (not an instance)."""
    helper_cls = MLUtils
    return helper_cls
class MLUtils:
    """Helpers for ML-related tests."""

    @staticmethod
    def get_test_device():
        """Return the module-level ``test_device`` configured for this run."""
        device = test_device
        return device