Skip to content

Commit 8958c65

Browse files
inailuigtensorflower-gardener
authored andcommitted
PR tensorflow#7849: [XLA:CPU] Add support for cross-process collectives using mpi.
Imported from GitHub PR openxla/xla#7849 Mpi collectives as proposed in jax-ml/jax#11182. I only implemented the inter-process communication and this does not yet support more than 1 threads per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly. For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) mpi library at runtime. To wrap and load the desired mpi library one needs compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp Copybara import of the project: -- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>: add mpi collectives -- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>: add explicit Init and Finalize methods and export them to python -- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>: add comment -- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>: fix windows build -- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>: fmt -- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>: bump xla_extension_version Merging this change closes tensorflow#7849 PiperOrigin-RevId: 620302264
1 parent 66ee739 commit 8958c65

File tree

15 files changed

+1054
-1
lines changed

15 files changed

+1054
-1
lines changed

third_party/mpitrampoline/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])

third_party/mpitrampoline/gen.patch

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
diff --git a/gen/gen_decl.py b/gen/gen_decl.py
2+
index 1005b95..696b4e0 100755
3+
--- a/gen/gen_decl.py
4+
+++ b/gen/gen_decl.py
5+
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
6+
7+
from mpi_constants import constants
8+
from mpi_functions import functions
9+
-from mpi_constants_fortran import constants_fortran
10+
-from mpi_functions_fortran import functions_fortran
11+
+# from mpi_constants_fortran import constants_fortran
12+
+# from mpi_functions_fortran import functions_fortran
13+
14+
support_profiling = True
15+
have_weak_symbols = False
16+
@@ -24,7 +24,7 @@ def wrap(line):
17+
lines.append(line)
18+
return "\n".join(lines)
19+
20+
-with open("include/mpi_decl_constants_c.h", "w") as file:
21+
+with open(sys.argv[1], "w") as file:
22+
file.write("// Declare C MPI constants\n")
23+
file.write("\n")
24+
for (tp, nm) in constants:
25+
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file:
26+
'mpi_nm': nm}
27+
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs))
28+
29+
-with open("include/mpi_decl_functions_c.h", "w") as file:
30+
+with open(sys.argv[2], "w") as file:
31+
file.write("// Declare C MPI functions\n")
32+
file.write("\n")
33+
for (tp, nm, args, flags) in functions:
34+
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file:
35+
file.write(Template("\n".join(tmpl)).substitute(subs))
36+
file.write("\n")
37+
38+
-with open("include/mpi_decl_constants_fortran.h", "w") as file:
39+
+if False:
40+
file.write("! Declare Fortran MPI constants\n")
41+
file.write("\n")
42+
for (tp, nm) in constants_fortran:
43+
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file:
44+
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl)))
45+
file.write("\n")
46+
47+
-with open("include/mpi_decl_functions_fortran.h", "w") as file:
48+
+if False:
49+
file.write("! Declare Fortran MPI functions\n")
50+
file.write("\n")
51+
for (tp, nm, args) in functions_fortran:
52+
diff --git a/gen/gen_defn.py b/gen/gen_defn.py
53+
index bf31f35..318222e 100755
54+
--- a/gen/gen_defn.py
55+
+++ b/gen/gen_defn.py
56+
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
57+
58+
from mpi_constants import constants
59+
from mpi_functions import functions
60+
-from mpi_constants_fortran import constants_fortran
61+
-from mpi_functions_fortran import functions_fortran
62+
+# from mpi_constants_fortran import constants_fortran
63+
+# from mpi_functions_fortran import functions_fortran
64+
65+
support_profiling = True
66+
have_weak_symbols = False
67+
replace_sentinels = False
68+
69+
-with open("src/mpi_defn_constants_c.h", "w") as file:
70+
+with open(sys.argv[1], "w") as file:
71+
file.write("// Define C MPI constants")
72+
file.write("\n")
73+
for (tp, nm) in constants:
74+
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file:
75+
'mpi_nm': nm}
76+
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs))
77+
78+
-with open("src/mpi_defn_functions_c.h", "w") as file:
79+
+with open(sys.argv[2], "w") as file:
80+
file.write("// Define C MPI functions\n")
81+
file.write("\n")
82+
for (tp, nm, args, flags) in functions:
83+
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file:
84+
file.write(Template("\n".join(tmpl)).substitute(subs))
85+
file.write("\n")
86+
87+
-with open("src/mpi_defn_constants_fortran.h", "w") as file:
88+
+if False:
89+
file.write("// Define Fortran MPI constants\n")
90+
file.write("\n")
91+
for (tp, nm) in constants_fortran:
92+
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file:
93+
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes
94+
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs))
95+
96+
-with open("src/mpi_defn_functions_fortran.h", "w") as file:
97+
+if False:
98+
file.write("// Define Fortran MPI functions\n")
99+
file.write("\n")
100+
for (tp, nm, args) in functions_fortran:
101+
diff --git a/gen/gen_init.py b/gen/gen_init.py
102+
index 4939261..0e52822 100755
103+
--- a/gen/gen_init.py
104+
+++ b/gen/gen_init.py
105+
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
106+
107+
from mpi_constants import constants
108+
from mpi_functions import functions
109+
-from mpi_constants_fortran import constants_fortran
110+
-from mpi_functions_fortran import functions_fortran
111+
+# from mpi_constants_fortran import constants_fortran
112+
+# from mpi_functions_fortran import functions_fortran
113+
114+
support_profiling = True
115+
have_weak_symbols = False
116+
replace_sentinels = False
117+
118+
-with open("src/mpi_init_constants_c.h", "w") as file:
119+
+with open(sys.argv[1], "w") as file:
120+
file.write("// Initialize C MPI constants")
121+
file.write("\n")
122+
for (tp, nm) in constants:
123+
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file:
124+
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)}
125+
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
126+
127+
-with open("src/mpi_init_functions_c.h", "w") as file:
128+
+with open(sys.argv[2], "w") as file:
129+
file.write("// Initialize C MPI functions\n")
130+
file.write("\n")
131+
for (tp, nm, args, flags) in functions:
132+
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file:
133+
subs['anm{0}'.format(i)] = anm
134+
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
135+
136+
-with open("src/mpi_init_constants_fortran.h", "w") as file:
137+
+if False:
138+
file.write("// Initialize Fortran MPI constants\n")
139+
file.write("\n")
140+
for (tp, nm) in constants_fortran:
141+
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file:
142+
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"}
143+
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
144+
145+
-with open("src/mpi_init_functions_fortran.h", "w") as file:
146+
+if False:
147+
file.write("// Initialize Fortran MPI functions\n")
148+
file.write("\n")
149+
for (tp, nm, args) in functions_fortran:
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Description:
2+
# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI
3+
4+
load("@org_tensorflow//xla:strict.default.bzl", "py_strict_binary")
5+
load("//third_party/bazel_skylib/rules:expand_template.bzl", "expand_template")
6+
7+
package(
8+
default_visibility = ["//visibility:public"],
9+
)
10+
11+
licenses(["notice"])
12+
13+
exports_files(["LICENSE.md"])
14+
15+
genrule(
16+
name = "mpi_version",
17+
srcs = [
18+
"CMakeLists.txt",
19+
"include/mpi_version.h.in",
20+
],
21+
outs = ["include/mpi_version.h"],
22+
cmd = """
23+
PROJECT_VERSION=`cat $(location CMakeLists.txt) \
24+
| grep "MPItrampoline VERSION" | awk '{print $$NF}'`
25+
PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1`
26+
PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2`
27+
PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3`
28+
sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \
29+
-e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \
30+
-e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \
31+
-e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \
32+
$(location include/mpi_version.h.in) > $(location include/mpi_version.h)
33+
""",
34+
)
35+
36+
expand_template(
37+
name = "mpi_defaults",
38+
out = "src/mpi_defaults.h",
39+
substitutions = {
40+
"@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "",
41+
"@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "",
42+
"@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "",
43+
"@MPITRAMPOLINE_DEFAULT_LIB@": "",
44+
"@MPITRAMPOLINE_DEFAULT_PRELOAD@": "",
45+
"@MPITRAMPOLINE_DEFAULT_VERBOSE@": "",
46+
},
47+
template = "src/mpi_defaults.h.in",
48+
)
49+
50+
py_strict_binary(
51+
name = "gen_decl",
52+
srcs = [
53+
"gen/gen_decl.py",
54+
"mpiabi/mpi_constants.py",
55+
"mpiabi/mpi_functions.py",
56+
],
57+
)
58+
59+
genrule(
60+
name = "decl",
61+
outs = [
62+
"include/mpi_decl_constants_c.h",
63+
"include/mpi_decl_functions_c.h",
64+
],
65+
cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \
66+
$(location include/mpi_decl_functions_c.h)",
67+
tools = [":gen_decl"],
68+
)
69+
70+
py_strict_binary(
71+
name = "gen_defn",
72+
srcs = [
73+
"gen/gen_defn.py",
74+
"mpiabi/mpi_constants.py",
75+
"mpiabi/mpi_functions.py",
76+
],
77+
)
78+
79+
genrule(
80+
name = "defn",
81+
outs = [
82+
"include/mpi_defn_constants_c.h",
83+
"include/mpi_defn_functions_c.h",
84+
],
85+
cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \
86+
$(location include/mpi_defn_functions_c.h)",
87+
tools = [":gen_defn"],
88+
)
89+
90+
py_strict_binary(
91+
name = "gen_init",
92+
srcs = [
93+
"gen/gen_init.py",
94+
"mpiabi/mpi_constants.py",
95+
"mpiabi/mpi_functions.py",
96+
],
97+
)
98+
99+
genrule(
100+
name = "init",
101+
outs = [
102+
"include/mpi_init_constants_c.h",
103+
"include/mpi_init_functions_c.h",
104+
],
105+
cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \
106+
$(location include/mpi_init_functions_c.h)",
107+
tools = [":gen_init"],
108+
)
109+
110+
cc_library(
111+
name = "mpitrampoline",
112+
srcs = [
113+
"src/mpi.c",
114+
],
115+
hdrs = [
116+
"include/mpi.h",
117+
"include/mpi_decl_constants_c.h",
118+
"include/mpi_decl_functions_c.h",
119+
"include/mpi_defn_constants_c.h",
120+
"include/mpi_defn_functions_c.h",
121+
"include/mpi_init_constants_c.h",
122+
"include/mpi_init_functions_c.h",
123+
"include/mpi_version.h",
124+
"mpiabi/mpiabi.h",
125+
"src/mpi_defaults.h",
126+
],
127+
copts = [
128+
"-fexceptions",
129+
],
130+
includes = [
131+
"include",
132+
"mpiabi",
133+
"src",
134+
],
135+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Provides the repository macro to import mpitrampoline."""
2+
3+
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
4+
5+
def repo():
6+
"""Imports mpitrampoline."""
7+
8+
MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"
9+
MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84"
10+
11+
tf_http_archive(
12+
name = "mpitrampoline",
13+
sha256 = MPITRAMPOLINE_SHA256,
14+
strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT),
15+
urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)),
16+
patch_file = ["//third_party/mpitrampoline:gen.patch"],
17+
build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD",
18+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])

0 commit comments

Comments
 (0)