Skip to content

Commit 0a9039e

Browse files
committed
XLA: vendor the runtime mlir backend
1 parent abd21e3 commit 0a9039e

File tree

3 files changed

+90
-9
lines changed

3 files changed

+90
-9
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
with:
3535
path: "~/.cache/bazel"
3636
key: bazel-${{ matrix.os }}
37-
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
37+
- run: sudo find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
3838
- run: |
3939
bazel build :enzyme_ad @llvm-project//llvm:FileCheck
4040
bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps

src/enzyme_ad/jax/compile_with_xla.cc

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
#include "xla/service/service.h"
33
#undef protected
44

5+
// Needed to access CompileXlaRuntimeCpuExecutable/etc
6+
#define private public
7+
#include "xla/service/cpu/cpu_compiler.h"
8+
#undef private
9+
10+
#include "xla/service/compiler.h"
511
#include "xla/service/cpu/cpu_executable.h"
12+
#include "xla/service/hlo_module_util.h"
13+
#include "xla/service/hlo_proto_util.h"
614
#include "xla/service/local_service_utils.h"
715

816
#include "absl/status/statusor.h"
@@ -27,6 +35,8 @@
2735
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
2836
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
2937

38+
#include "xla/statusor.h"
39+
3040
#include "pybind11/pybind11.h"
3141

3242
#include "compile_with_xla.h"
@@ -161,6 +171,80 @@ run_pass_pipeline(const std::vector<std::string> &oldsym_vec,
161171
return std::make_pair(entryfn.str(), ss.str());
162172
}
163173

174+
absl::StatusOr<std::unique_ptr<xla::Executable>>
175+
RunBackend(xla::cpu::CpuCompiler *self, std::unique_ptr<xla::HloModule> module,
176+
[[maybe_unused]] xla::se::StreamExecutor *stream_exec,
177+
const xla::Compiler::CompileOptions &options, bool xla_runtime) {
178+
179+
std::unique_ptr<xla::cpu::CpuExecutable> cpu_executable;
180+
if (xla_runtime) {
181+
TF_ASSIGN_OR_RETURN(cpu_executable,
182+
self->CompileXlaRuntimeCpuExecutable(std::move(module),
183+
options.registry));
184+
} else {
185+
TF_ASSIGN_OR_RETURN(cpu_executable,
186+
self->CompileLegacyCpuExecutable(std::move(module)));
187+
}
188+
189+
return std::unique_ptr<xla::Executable>(std::move(cpu_executable));
190+
}
191+
192+
absl::StatusOr<std::unique_ptr<xla::Executable>>
193+
BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
194+
std::unique_ptr<xla::HloModuleConfig> module_config,
195+
xla::Backend *backend, xla::se::StreamExecutor *executor,
196+
const xla::Compiler::CompileOptions &options,
197+
bool run_backend_only, bool xla_runtime) {
198+
199+
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
200+
xla::CreateModuleFromProto(module_proto, *module_config,
201+
run_backend_only));
202+
xla::UpdateEntryComputationLayout(
203+
module.get(), std::bind(&xla::Compiler::DefaultDeviceShapeRepresentation,
204+
backend->compiler(), std::placeholders::_1));
205+
// xla::DumpHloModuleIfEnabled(*module, xla::kBeforeOptimizationsDumpName);
206+
207+
std::unique_ptr<xla::HloProto> hlo_proto_before_opt;
208+
if (!run_backend_only) {
209+
// Save proto state before optimizations if we want a snapshot.
210+
// When run_backend_only is enabled the post-optimization HLO will be the
211+
// same as the pre-optimization HLO.
212+
// if (xla::DumpingEnabledForHloModule(*module)) {
213+
// hlo_proto_before_opt =
214+
// std::make_unique<xla::HloProto>(MakeHloProto(*module));
215+
// }
216+
TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
217+
std::move(module), executor, options));
218+
}
219+
220+
/*
221+
TF_ASSIGN_OR_RETURN(
222+
std::unique_ptr<xla::Executable> executable,
223+
backend->compiler()->RunBackend(std::move(module), executor, options));
224+
*/
225+
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Executable> executable,
226+
RunBackend((xla::cpu::CpuCompiler *)backend->compiler(),
227+
std::move(module), executor, options,
228+
xla_runtime));
229+
230+
const xla::BufferAssignmentProto *buffer_assignment_proto_after_opt =
231+
executable->buffer_assignment_proto();
232+
233+
// If dumping is enabled RunBackend(...) will emit a hlo_proto in the
234+
// executable. This contains the buffer_assignment that is only available
235+
// after RunBackend(). If hlo_proto_before_opt is not null, then we replace
236+
// its buffer_assignment with the one from after_opt and then store it into
237+
// the executable.
238+
if (hlo_proto_before_opt != nullptr &&
239+
buffer_assignment_proto_after_opt != nullptr) {
240+
// CHECK(xla::DumpingEnabledForHloModule(executable->module()));
241+
*hlo_proto_before_opt->mutable_buffer_assignment() =
242+
std::move(*buffer_assignment_proto_after_opt);
243+
executable->set_hlo_proto(std::move(hlo_proto_before_opt));
244+
}
245+
return std::move(executable);
246+
}
247+
164248
// Compile an MHLO module given as a string to LLVM IR using XLA.
165249
std::unique_ptr<xla::LocalExecutable>
166250
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
@@ -288,10 +372,11 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
288372
build_options.device_allocator(), build_options.compile_thread_pool(),
289373
build_options.layout_canonicalization_callback()};
290374
opts.registry = &registry;
291-
auto executable = local_client->local_service()->BuildExecutable(
292-
xla_computation.proto(), std::move(module_config_or_error.value()),
293-
local_client->mutable_backend(), executor.value(), opts,
294-
build_options.run_backend_only());
375+
auto executable =
376+
BuildExecutable(local_client->local_service(), xla_computation.proto(),
377+
std::move(module_config_or_error.value()),
378+
local_client->mutable_backend(), executor.value(), opts,
379+
build_options.run_backend_only(), xla_runtime);
295380
if (!executable.ok()) {
296381
throw pybind11::value_error(executable.status().ToString());
297382
}

src/enzyme_ad/jax/primitives.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,8 @@ def __init__(self, passes=None, mlirad=False):
9696
region-simplify=true
9797
test-convergence=false
9898
top-down=true},
99-
func.func(xla-sparse-custom-call-to-pack),
100-
func.func(legalize-sparse-ops{legalize-to-custom-calls=false}),
10199
func.func(chlo-legalize-to-hlo{
102100
expand-compositions=true legalize-broadcasts=true}),
103-
func.func(mhlo-sparse-rewriting),
104101
func.func(mhlo-legalize-control-flow),
105102
func.func(mhlo-legalize-dot-general-to-dot),
106103
hlo-legalize-to-arithmetic,
@@ -200,7 +197,6 @@ def __init__(self, passes=None, mlirad=False):
200197
test-convergence=false
201198
top-down=true},
202199
cse,
203-
func.func(xla-math-approximation{oplist=all}),
204200
func.func(convert-linalg-to-parallel-loops),
205201
canonicalize{
206202
max-iterations=10

0 commit comments

Comments
 (0)