forked from pytorch/builder
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_binary_symbols.py
executable file
·91 lines (81 loc) · 3.59 KB
/
check_binary_symbols.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
import concurrent.futures
import distutils.sysconfig
import itertools
import functools
import os
import re
from pathlib import Path
from typing import Any, List, Tuple
# ABI sanity data for libtorch: depending on the build flavour, symbols that
# belong to the *other* std:: ABI generation must not show up in the library.
#
# A cxx11-ABI build must not contain any of these pre-cxx11 spellings:
PRE_CXX11_SYMBOLS = (
    "std::basic_string<",
    "std::list",
)
# A pre-cxx11-ABI build must not contain any of these cxx11 spellings:
CXX11_SYMBOLS = (
    "std::__cxx11::basic_string",
    "std::__cxx11::list",
)
# NOTE: Searching for the spellings above across every namespace is not
# reliable: devtoolset7 always emits some cxx11 symbols even for old-ABI
# builds, and CuDNN always carries pre-cxx11 symbols even for new-ABI builds
# done with gcc 5.4. The search is therefore restricted to symbols that live
# in the following namespaces:
LIBTORCH_NAMESPACE_LIST = (
    "c10::",
    "at::",
    "caffe2::",
    "torch::",
)

# One compiled regex per (namespace, symbol) pair, in namespace-major order.
LIBTORCH_CXX11_PATTERNS = [
    re.compile(f"{namespace}.*{symbol}")
    for namespace in LIBTORCH_NAMESPACE_LIST
    for symbol in CXX11_SYMBOLS
]
LIBTORCH_PRE_CXX11_PATTERNS = [
    re.compile(f"{namespace}.*{symbol}")
    for namespace in LIBTORCH_NAMESPACE_LIST
    for symbol in PRE_CXX11_SYMBOLS
]
@functools.lru_cache(100)
def get_symbols(lib: str) -> List[Tuple[str, str, str]]:
    """Return the demangled symbol table of *lib* as reported by ``nm``.

    The output of ``nm <lib>`` is piped through ``c++filt`` and every output
    line is split on its first two spaces, which for a regular ``nm`` line
    yields ``(address, type, name)``. Results are memoized per *lib*, so
    repeated queries against the same library run the tools only once.

    Args:
        lib: Path of the library to inspect.

    Returns:
        One tuple per output line, in ``nm`` order.

    Raises:
        subprocess.CalledProcessError: if ``nm`` or ``c++filt`` exits with a
            non-zero status (e.g. the library does not exist).
        FileNotFoundError: if ``nm`` or ``c++filt`` is not installed.
    """
    import subprocess

    # Run both tools without a shell: the path is passed as a plain argv
    # entry, so quotes or other shell metacharacters in it are never
    # interpreted, and a failing `nm` is reported instead of being masked by
    # the exit status of the last command in a shell pipeline.
    nm_output = subprocess.run(
        ["nm", str(lib)], stdout=subprocess.PIPE, check=True
    ).stdout
    demangled = subprocess.run(
        ["c++filt"], input=nm_output, stdout=subprocess.PIPE, check=True
    ).stdout

    # latin1 maps every byte to a character, so decoding cannot fail.
    # The final element after splitting is the text following the last
    # newline and is dropped.
    lines = demangled.decode("latin1").split("\n")[:-1]
    # Tuples keep the cached rows immutable and match the annotated type.
    return [tuple(line.split(" ", 2)) for line in lines]
def grep_symbols(lib: str, patterns: List[Any]) -> List[str]:
    """Return the demangled symbol names of *lib* that match any pattern.

    A symbol is reported at most once, even if several patterns match it.
    Matching uses ``pattern.match``, i.e. each pattern is anchored at the
    start of the symbol name.

    Args:
        lib: Path of the library whose symbols are searched.
        patterns: Objects exposing a ``match(str)`` method, e.g. compiled
            regular expressions.

    Returns:
        Matching symbol names, in the order ``get_symbols`` lists them.
    """

    def _grep_symbols(
        symbols: List[Tuple[str, str, str]], patterns: List[Any]
    ) -> List[str]:
        # Scan one chunk of the symbol table.
        matched = []
        for _addr, _type, name in symbols:
            for pattern in patterns:
                if pattern.match(name):
                    matched.append(name)
                    # Stop at the first hit so the symbol is counted once.
                    break
        return matched

    all_symbols = get_symbols(lib)
    if not all_symbols:
        # Nothing to scan; avoid spinning up the worker pool.
        return []

    num_workers = 32
    # Ceiling division: at most `num_workers` chunks cover every symbol.
    chunk_size = (len(all_symbols) + num_workers - 1) // num_workers
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        tasks = [
            executor.submit(
                _grep_symbols, all_symbols[start:start + chunk_size], patterns
            )
            for start in range(0, len(all_symbols), chunk_size)
        ]
        # Collect in submission order so the result follows the symbol table.
        results: List[str] = []
        for task in tasks:
            results.extend(task.result())
    return results
def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True) -> None:
    """Verify that *lib* only carries symbols of the expected std:: ABI.

    Prints the library path and the number of cxx11 / pre-cxx11 symbol
    matches, then validates them against the expected ABI flavour.

    Args:
        lib: Path of the library to check.
        pre_cxx11_abi: True when the library is expected to use the
            pre-cxx11 ABI, False when it is expected to use the cxx11 ABI.

    Raises:
        RuntimeError: if any symbol of the unexpected ABI is found, or if
            fewer symbols of the expected ABI are found than the minimum
            (1000 for pre-cxx11, 100 for cxx11).
    """
    print(f"lib: {lib}")
    new_abi_hits = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
    old_abi_hits = grep_symbols(lib, LIBTORCH_PRE_CXX11_PATTERNS)
    new_abi_count, old_abi_count = len(new_abi_hits), len(old_abi_hits)
    print(f"num_cxx11_symbols: {new_abi_count}")
    print(f"num_pre_cxx11_symbols: {old_abi_count}")

    if not pre_cxx11_abi:
        # cxx11 build: no old-ABI symbols allowed, plenty of new ones expected.
        if old_abi_count > 0:
            raise RuntimeError(f"Found pre-cxx11 symbols, but there shouldn't be any, see: {old_abi_hits[:100]}")
        if new_abi_count < 100:
            raise RuntimeError("Didn't find enought cxx11 symbols")
        return

    # pre-cxx11 build: no new-ABI symbols allowed, plenty of old ones expected.
    if new_abi_count > 0:
        raise RuntimeError(f"Found cxx11 symbols, but there shouldn't be any, see: {new_abi_hits[:100]}")
    if old_abi_count < 1000:
        raise RuntimeError("Didn't find enough pre-cxx11 symbols.")
def main() -> None:
    """Locate ``libtorch_cpu.so`` and check its symbols for ABI correctness.

    The install root is the current working directory when the
    ``PACKAGE_TYPE`` environment variable equals ``"libtorch"``; otherwise it
    is the ``torch`` directory inside the Python library directory. The
    library is expected to use the pre-cxx11 ABI unless the
    ``DESIRED_DEVTOOLSET`` environment variable contains ``"cxx11-abi"``.

    Raises:
        FileNotFoundError: if ``libtorch_cpu.so`` is not present under the
            install root.
        RuntimeError: propagated from the ABI check when the symbols do not
            match the expected ABI.
    """
    if os.getenv("PACKAGE_TYPE") == "libtorch":
        install_root = Path(os.getcwd())
    else:
        install_root = Path(distutils.sysconfig.get_python_lib()) / "torch"

    libtorch_cpu_path = install_root / "lib" / "libtorch_cpu.so"
    # Fail fast with a clear message if the library is not where we expect it.
    if not libtorch_cpu_path.is_file():
        raise FileNotFoundError(f"libtorch_cpu.so not found at {libtorch_cpu_path}")

    pre_cxx11_abi = "cxx11-abi" not in os.getenv("DESIRED_DEVTOOLSET", "")
    # The checker is annotated to take a str path, so convert explicitly.
    check_lib_symbols_for_abi_correctness(str(libtorch_cpu_path), pre_cxx11_abi)


if __name__ == "__main__":
    main()