From 4d0ad05d4091971669cf78847f61c5fb5366daa4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 15 Aug 2024 09:06:32 -0700 Subject: [PATCH] PR #15998: [GPU] Fix kernel cache for loaded executables. Imported from GitHub PR https://github.com/openxla/xla/pull/15998 LoadCache() has to be called also in GpuThunkAotCompilationResult::LoadExecutable() to properly initialize the state of the name uniquer in the IR emitter context on this path to execution. This makes the XLA kernel cache compatible with the JAX module cache (example test: `bazel test --test_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism --xla_gpu_kernel_cache_file=/dev/shm/xla.kernel.cache" tests/compilation_cache_test_gpu`). Copybara import of the project: -- b51fbcac5c7f172d06eb9de79770564f2e2c1250 by Ilia Sergachev : [GPU] Fix kernel cache for loaded executables. This makes the XLA kernel cache compatible with JAX module cache. Merging this change closes #15998 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15998 from openxla:fix_kernel_cache b51fbcac5c7f172d06eb9de79770564f2e2c1250 PiperOrigin-RevId: 663326327 --- xla/service/gpu/BUILD | 9 ++- xla/service/gpu/compile_module_to_llvm_ir.cc | 13 ++-- xla/service/gpu/compile_module_to_llvm_ir.h | 4 + xla/service/gpu/gpu_compiler.cc | 10 +++ xla/service/gpu/gpu_compiler_test.cc | 80 ++++++++++++++++++++ 5 files changed, 106 insertions(+), 10 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 102aa0ace5de79..273934b77f7039 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1281,10 +1281,8 @@ cc_library( "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:conditional_thunk", "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", - "//xla/service/gpu/runtime:while_thunk", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/rocm:rocm_platform_id", @@ -1611,9 +1609,13 @@ xla_test( ":metrics", "//xla:autotune_results_proto_cc", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", @@ -1621,8 +1623,11 @@ xla_test( "//xla/service:xla_debug_info_manager", "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", diff --git a/xla/service/gpu/compile_module_to_llvm_ir.cc b/xla/service/gpu/compile_module_to_llvm_ir.cc index c21784b1b3dda8..abc288e4a96d20 100644 --- a/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" @@ -53,10 +54,6 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/runtime/conditional_thunk.h" -#include "xla/service/gpu/runtime/sequential_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" @@ -102,8 +99,10 @@ void RemoveUnusedAndUninitializedGlobals( } } -static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, - absl::string_view cache_file_path) { +} // namespace + +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path) { std::string resolved_path; if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) { return FailedPrecondition("File path can not be resolved: %s", @@ -131,8 +130,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, return absl::OkStatus(); } -} // namespace - absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/xla/service/gpu/compile_module_to_llvm_ir.h b/xla/service/gpu/compile_module_to_llvm_ir.h index d7005f879c3994..a451af5a149fad 100644 --- a/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/xla/service/gpu/compile_module_to_llvm_ir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" @@ -66,6 +67,9 @@ struct CompileModuleResults { void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence); +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path); + absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index eb538d1e17c666..be7ba5c92c3da9 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -410,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable( platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), /*llvm_module_constants=*/nullptr, /*emit_kernels=*/false); + + absl::string_view cache_file_path = + hlo_module->config().debug_options().xla_gpu_kernel_cache_file(); + if (!cache_file_path.empty() && + hlo_module->config() + .debug_options() + .xla_gpu_enable_llvm_module_compilation_parallelism()) { + TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path)); + } + auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); TF_RETURN_IF_ERROR( ir_emitter->EmitHloComputation(hlo_module->entry_computation())); diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 93057c595976dd..9e07fb6e38fe66 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -33,8 +34,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_hlo_schedule.h" @@ -44,8 +49,11 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" @@ -804,6 +812,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) { EXPECT_EQ(CacheEntryCount(), 4); } +TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) { + const std::string kHloAdd1 = R"( +add1 { + p = s32[] parameter(0) + c = s32[] constant(1) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add1 +})"; + + const std::string kHloAdd2 = R"( +add2 { + p = s32[] parameter(0) + c = s32[] constant(2) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add2 +})"; + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName("cuda")); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + Compiler* compiler = backend().compiler(); + AotCompilationOptions aot_options(compiler->PlatformId()); + aot_options.set_executor(stream_exec); + + auto test = [this, &compiler, &aot_options](absl::string_view hlo, int input, + int expected_result) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto module_group = std::make_unique(std::move(module)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_results, + compiler->CompileAheadOfTime(std::move(module_group), aot_options)); + + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + aot_results[0]->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + aot_result->LoadExecutable(compiler, aot_options.executor())); + + const xla::Literal literal_input = + xla::LiteralUtil::CreateR0(input); + const xla::Literal literal_expected_result = + xla::LiteralUtil::CreateR0(expected_result); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, + GetHloRunner().value()->ExecuteWithExecutable( + executable.get(), {&literal_input})); + + EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result)); + }; + + test(kHloAdd1, 1, 2); + test(kHloAdd2, 1, 3); + // The test used to fail on the second execution of the second module when it + // was already cached. + test(kHloAdd2, 1, 3); +} + class KernelCacheTestSingleThreaded : public KernelCacheTest { public: DebugOptions GetDebugOptionsForTest() override {