forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
write_array_api_tests_k_flag.py
76 lines (68 loc) · 2.26 KB
/
write_array_api_tests_k_flag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
# Directory containing this script; the methods-to-test files live beside it.
this_dir = os.path.dirname(os.path.realpath(__file__))
func_folder = os.path.join(this_dir, "array_api_methods_to_test")

# api function filepaths, in sorted filename order so the output is deterministic
func_fnames = sorted(os.listdir(func_folder))
func_fpaths = [os.path.join(func_folder, func_fname) for func_fname in func_fnames]

# all filepaths that get parsed (currently just the api function files)
fpaths = func_fpaths

# Per-framework lists of pytest test names to run / to skip. The key order
# (jax, numpy, torch, tensorflow) is the order frameworks are processed in.
framework_tests_to_run = {
    framework: [] for framework in ("jax", "numpy", "torch", "tensorflow")
}
framework_tests_to_skip = {framework: [] for framework in framework_tests_to_run}
# add from each filepath
#
# Each non-empty line of a methods-to-test file is one entry. Dunders are
# removed first, so ``__abs__`` is treated as ``abs``. Expected formats:
#   * ``abs``           -> run for every framework;
#   * ``#abs jax ...``  -> skip for the frameworks named on the line, run for the rest;
#   * ``#abs``          -> no framework named, so skip for every framework.
# Framework names are matched case-insensitively as substrings of the line.


def _test_name(entry):
    """Return the pytest test name for one methods-to-test entry.

    Uncommented entries map to ``"test_" + entry``. For commented entries the
    name is the text between the first ``#`` and the next space, so
    ``"#abs # jax"`` maps to ``"test_abs"``. Returns ``None`` when a commented
    entry has no name in that position (e.g. ``"# note"``): the bare prefix
    ``"test_"`` would otherwise match every test in a ``-k`` expression.
    """
    if "#" not in entry:
        return "test_" + entry
    name = entry.split("#")[1].split(" ")[0]
    return "test_" + name if name else None


for fpath in fpaths:
    # extract contents; an explicit encoding keeps parsing locale-independent
    with open(fpath, "r", encoding="utf-8") as file:
        contents = file.read()
    # splitlines() also copes with CRLF files, and strip() drops stray
    # whitespace. Otherwise "\r" or blanks leak into the generated -k
    # expression, and a whitespace-only line becomes the match-all term "test_".
    entries = [line.replace("__", "").strip() for line in contents.splitlines()]
    for entry in entries:
        if not entry:
            continue
        test_name = _test_name(entry)
        if test_name is None:
            continue
        # Framework-independent facts about the entry, computed once.
        lowered = entry.lower()
        commented = "#" in entry
        named = {fw for fw in framework_tests_to_run if fw in lowered}
        for framework in framework_tests_to_run:
            # Uncommented -> run. Commented and naming other frameworks only ->
            # run here. Commented and naming this framework, or naming none -> skip.
            runs = not commented or (bool(named) and framework not in named)
            target = framework_tests_to_run if runs else framework_tests_to_skip
            target[framework].append(test_name)
# Prune tests to skip: drop every skip entry that is a substring of a test this
# framework runs. (Presumably because pytest's -k matching is substring-based,
# so "not test_add" would also deselect a wanted "test_add_..." test.)
for framework in framework_tests_to_skip:
    tests_to_run = framework_tests_to_run[framework]
    # any() instead of max([...]): max() raises ValueError when this framework
    # has no tests to run at all, whereas any() just keeps the skip entry.
    framework_tests_to_skip[framework] = [
        tts
        for tts in framework_tests_to_skip[framework]
        if not any(tts in ttr for ttr in tests_to_run)
    ]
# save to file: one pytest -k expression per framework, written to
# ".array_api_tests_k_flag_<framework>" next to this script.
for framework, tests_to_run in framework_tests_to_run.items():
    expression = "(" + " or ".join(tests_to_run) + ")"
    tests_to_skip = framework_tests_to_skip[framework]
    # "not ()" is not a valid -k expression, so the exclusion clause is only
    # added when there is something to exclude. NOTE(review): an empty run
    # list still yields "()", which pytest rejects.
    if tests_to_skip:
        expression += " and not (" + " or ".join(tests_to_skip) + ")"
    k_flag_path = os.path.join(this_dir, ".array_api_tests_k_flag_" + framework)
    with open(k_flag_path, "w", encoding="utf-8") as file:
        file.write(expression)