Skip to content

Commit 6137ac9

Browse files
giordanoenzyme-ci-bot[bot]wsmoses
authored
Update JAX to commit 45065d569064392cee45a066f9f788fe29ee2cd8 (#1418)
* Update JAX to commit 45065d569064392cee45a066f9f788fe29ee2cd8 Diff: jax-ml/jax@4455434...45065d5 * Drop unneeded patch * fixup * Fix * fix * fix tests * fix * add missing file * fix * fix * fixup * bug workaround for gcc * fix * fix * more patches * fix * fix * fix --------- Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com> Co-authored-by: William S. Moses <gh@wsmoses.com>
1 parent 996b3e5 commit 6137ac9

21 files changed

+118
-206
lines changed

WORKSPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
3333
# ],
3434
# )
3535

36+
load("//third_party/ml_toolchain:workspace.bzl", ml_toolchain_workspace = "repo")
37+
38+
ml_toolchain_workspace()
39+
3640
load("//third_party/jax:workspace.bzl", jax_workspace = "repo")
3741

3842
jax_workspace()

patches/xla.patch

Lines changed: 0 additions & 101 deletions
This file was deleted.

src/enzyme_ad/jax/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,6 @@ cc_library(
979979
"@triton//third_party/nvidia:NVWSDialect",
980980
"@triton//third_party/nvidia:NVWSTransforms",
981981
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
982-
"@triton//third_party/proton:ProtonIRDialect",
983982

984983
# Shardy stuff
985984
"@shardy//shardy/dialect/sdy/ir:dialect",

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2832,10 +2832,15 @@ struct WhileOpEnzymeOpsRemover
28322832
auto zero = makeI64Constant(whileOp->getLoc(), rewriter, 0);
28332833

28342834
// Run min cut partitioning to limit the amount of values to be cached.
2835-
if (!caches.empty() && !whileOp->hasAttr("enzymexla.disable_min_cut")) {
2835+
if (hasMinCut(whileOp) && caches.size()) {
28362836
Block *forward = &whileOp.getBody().front();
28372837
Block *reverse = &otherWhileOp.getBody().front();
2838-
mlir::enzyme::minCutCache(forward, reverse, caches, rewriter);
2838+
Operation *lastFwd = nullptr;
2839+
IRMapping fwdrevmap;
2840+
OpBuilder::InsertionGuard guard(rewriter);
2841+
rewriter.setInsertionPointToStart(reverse);
2842+
mlir::enzyme::minCutCache(forward, reverse, caches, rewriter, fwdrevmap,
2843+
lastFwd);
28392844
}
28402845

28412846
Value itersV = nullptr;

src/enzyme_ad/jax/clang_compile.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ struct tensor<T, n0, N...>
480480
fuseFS->pushOverlay(fs);
481481
fuseFS->pushOverlay(baseFS);
482482

483-
Clang->createFileManager(fuseFS);
483+
Clang->createVirtualFileSystem(fuseFS);
484+
Clang->createFileManager();
484485

485486
bool Success = CompilerInvocation::CreateFromArgs(
486487
Clang->getInvocation(), Argv.getArguments(), Diags, binary);
@@ -492,7 +493,7 @@ struct tensor<T, n0, N...>
492493
CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/ 0x0);
493494

494495
// Create the actual diagnostics engine.
495-
Clang->createDiagnostics(*fuseFS);
496+
Clang->setDiagnostics(Clang->createDiagnostics(*fuseFS, DiagOpts));
496497
if (!Clang->hasDiagnostics()) {
497498
llvm::errs() << " failed create diag\n";
498499
return {};

src/enzyme_ad/jax/compile_with_xla.cc

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -205,31 +205,12 @@ run_pass_pipeline(const std::vector<std::string> &oldsym_vec,
205205
return std::make_pair(entryfn.str(), ss.str());
206206
}
207207

208-
absl::StatusOr<std::unique_ptr<xla::Executable>>
209-
RunBackend(xla::cpu::CpuCompiler *self, std::unique_ptr<xla::HloModule> module,
210-
[[maybe_unused]] xla::se::StreamExecutor *stream_exec,
211-
const xla::Compiler::CompileOptions &options, bool xla_runtime) {
212-
213-
std::unique_ptr<xla::cpu::CpuExecutable> cpu_executable;
214-
if (xla_runtime) {
215-
throw nanobind::value_error("xla_runtime deprecated upstream");
216-
// TF_ASSIGN_OR_RETURN(cpu_executable,
217-
// self->CompileXlaRuntimeCpuExecutable(std::move(module),
218-
// options.registry));
219-
} else {
220-
TF_ASSIGN_OR_RETURN(cpu_executable,
221-
self->CompileCpuExecutable(std::move(module)));
222-
}
223-
224-
return std::unique_ptr<xla::Executable>(std::move(cpu_executable));
225-
}
226-
227208
absl::StatusOr<std::unique_ptr<xla::Executable>>
228209
BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
229210
std::unique_ptr<xla::HloModuleConfig> module_config,
230211
xla::Backend *backend, xla::se::StreamExecutor *executor,
231212
const xla::Compiler::CompileOptions &options,
232-
bool run_backend_only, bool xla_runtime) {
213+
bool run_backend_only, [[maybe_unused]] bool xla_runtime) {
233214

234215
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
235216
xla::CreateModuleFromProto(module_proto, *module_config,
@@ -252,15 +233,9 @@ BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
252233
std::move(module), executor, options));
253234
}
254235

255-
/*
256236
TF_ASSIGN_OR_RETURN(
257237
std::unique_ptr<xla::Executable> executable,
258238
backend->compiler()->RunBackend(std::move(module), executor, options));
259-
*/
260-
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Executable> executable,
261-
RunBackend((xla::cpu::CpuCompiler *)backend->compiler(),
262-
std::move(module), executor, options,
263-
xla_runtime));
264239

265240
const xla::BufferAssignmentProto *buffer_assignment_proto_after_opt =
266241
executable->buffer_assignment_proto();

0 commit comments

Comments
 (0)