Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# ],
# )

load("//third_party/ml_toolchain:workspace.bzl", ml_toolchain_workspace = "repo")

ml_toolchain_workspace()

load("//third_party/jax:workspace.bzl", jax_workspace = "repo")

jax_workspace()
Expand Down
101 changes: 0 additions & 101 deletions patches/xla.patch

This file was deleted.

1 change: 0 additions & 1 deletion src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,6 @@ cc_library(
"@triton//third_party/nvidia:NVWSDialect",
"@triton//third_party/nvidia:NVWSTransforms",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//third_party/proton:ProtonIRDialect",

# Shardy stuff
"@shardy//shardy/dialect/sdy/ir:dialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2832,10 +2832,15 @@ struct WhileOpEnzymeOpsRemover
auto zero = makeI64Constant(whileOp->getLoc(), rewriter, 0);

// Run min cut partitioning to limit the amount of values to be cached.
if (!caches.empty() && !whileOp->hasAttr("enzymexla.disable_min_cut")) {
if (hasMinCut(whileOp) && caches.size()) {
Block *forward = &whileOp.getBody().front();
Block *reverse = &otherWhileOp.getBody().front();
mlir::enzyme::minCutCache(forward, reverse, caches, rewriter);
Operation *lastFwd = nullptr;
IRMapping fwdrevmap;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(reverse);
mlir::enzyme::minCutCache(forward, reverse, caches, rewriter, fwdrevmap,
lastFwd);
}

Value itersV = nullptr;
Expand Down
5 changes: 3 additions & 2 deletions src/enzyme_ad/jax/clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ struct tensor<T, n0, N...>
fuseFS->pushOverlay(fs);
fuseFS->pushOverlay(baseFS);

Clang->createFileManager(fuseFS);
Clang->createVirtualFileSystem(fuseFS);
Clang->createFileManager();

bool Success = CompilerInvocation::CreateFromArgs(
Clang->getInvocation(), Argv.getArguments(), Diags, binary);
Expand All @@ -492,7 +493,7 @@ struct tensor<T, n0, N...>
CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/ 0x0);

// Create the actual diagnostics engine.
Clang->createDiagnostics(*fuseFS);
Clang->setDiagnostics(Clang->createDiagnostics(*fuseFS, DiagOpts));
if (!Clang->hasDiagnostics()) {
llvm::errs() << " failed create diag\n";
return {};
Expand Down
27 changes: 1 addition & 26 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,31 +205,12 @@ run_pass_pipeline(const std::vector<std::string> &oldsym_vec,
return std::make_pair(entryfn.str(), ss.str());
}

absl::StatusOr<std::unique_ptr<xla::Executable>>
RunBackend(xla::cpu::CpuCompiler *self, std::unique_ptr<xla::HloModule> module,
[[maybe_unused]] xla::se::StreamExecutor *stream_exec,
const xla::Compiler::CompileOptions &options, bool xla_runtime) {

std::unique_ptr<xla::cpu::CpuExecutable> cpu_executable;
if (xla_runtime) {
throw nanobind::value_error("xla_runtime deprecated upstream");
// TF_ASSIGN_OR_RETURN(cpu_executable,
// self->CompileXlaRuntimeCpuExecutable(std::move(module),
// options.registry));
} else {
TF_ASSIGN_OR_RETURN(cpu_executable,
self->CompileCpuExecutable(std::move(module)));
}

return std::unique_ptr<xla::Executable>(std::move(cpu_executable));
}

absl::StatusOr<std::unique_ptr<xla::Executable>>
BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
std::unique_ptr<xla::HloModuleConfig> module_config,
xla::Backend *backend, xla::se::StreamExecutor *executor,
const xla::Compiler::CompileOptions &options,
bool run_backend_only, bool xla_runtime) {
bool run_backend_only, [[maybe_unused]] bool xla_runtime) {

TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
xla::CreateModuleFromProto(module_proto, *module_config,
Expand All @@ -252,15 +233,9 @@ BuildExecutable(xla::Service *self, const xla::HloModuleProto &module_proto,
std::move(module), executor, options));
}

/*
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::Executable> executable,
backend->compiler()->RunBackend(std::move(module), executor, options));
*/
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Executable> executable,
RunBackend((xla::cpu::CpuCompiler *)backend->compiler(),
std::move(module), executor, options,
xla_runtime));

const xla::BufferAssignmentProto *buffer_assignment_proto_after_opt =
executable->buffer_assignment_proto();
Expand Down
Loading
Loading