PR #15998: [GPU] Fix kernel cache for loaded executables.
Imported from GitHub PR #15998

LoadCache() must also be called in GpuThunkAotCompilationResult::LoadExecutable() so that the name uniquer state in the IR emitter context is properly initialized on this execution path as well.

This makes the XLA kernel cache compatible with the JAX module cache (example test: `bazel test --test_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism --xla_gpu_kernel_cache_file=/dev/shm/xla.kernel.cache" tests/compilation_cache_test_gpu`).
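
The shape of the fix, as a minimal sketch condensed from the diff below (the actual code lives in compile_module_to_llvm_ir.h/.cc and gpu_compiler.cc):

```cpp
// compile_module_to_llvm_ir.h: LoadCache() is promoted from a file-local
// helper to a declared function so the AOT loading path can reuse it.
absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
                       absl::string_view cache_file_path);

// gpu_compiler.cc, GpuThunkAotCompilationResult::LoadExecutable(): right
// after the IrEmitterContext is constructed, load the kernel cache under the
// same flags that gate it on the regular compilation path.
absl::string_view cache_file_path =
    hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
if (!cache_file_path.empty() &&
    hlo_module->config()
        .debug_options()
        .xla_gpu_enable_llvm_module_compilation_parallelism()) {
  TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
}
```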
Copybara import of the project:

--
b51fbca by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Fix kernel cache for loaded executables.

This makes the XLA kernel cache compatible with the JAX module cache.

Merging this change closes #15998

FUTURE_COPYBARA_INTEGRATE_REVIEW=#15998 from openxla:fix_kernel_cache b51fbca
PiperOrigin-RevId: 663326327
sergachev authored and copybara-github committed Aug 16, 2024
1 parent d85d92f commit 4d0ad05
Showing 5 changed files with 106 additions and 10 deletions.
9 changes: 7 additions & 2 deletions xla/service/gpu/BUILD
@@ -1281,10 +1281,8 @@ cc_library(
"//xla/service:hlo_ordering",
"//xla/service:hlo_proto_cc",
"//xla/service:logical_buffer",
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/runtime:while_thunk",
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor/rocm:rocm_platform_id",
@@ -1611,18 +1609,25 @@ xla_test(
":metrics",
"//xla:autotune_results_proto_cc",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/service:compiler",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
13 changes: 5 additions & 8 deletions xla/service/gpu/compile_module_to_llvm_ir.cc
@@ -28,6 +28,7 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
@@ -53,10 +54,6 @@ limitations under the License.
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/ir_emitter_unnested.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/runtime/conditional_thunk.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/runtime/while_thunk.h"
#include "xla/service/hlo_dataflow_analysis.h"
#include "xla/service/hlo_ordering.h"
#include "xla/service/logical_buffer.h"
@@ -102,8 +99,10 @@ void RemoveUnusedAndUninitializedGlobals(
}
}

static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
} // namespace

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
std::string resolved_path;
if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) {
return FailedPrecondition("File path can not be resolved: %s",
@@ -131,8 +130,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
return absl::OkStatus();
}

} // namespace

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
4 changes: 4 additions & 0 deletions xla/service/gpu/compile_module_to_llvm_ir.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/service/gpu/executable.pb.h"
#include "xla/service/gpu/execution_stream_assignment.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/hlo.pb.h"
@@ -66,6 +67,9 @@ struct CompileModuleResults {
void ForAllThunks(const std::function<void(Thunk*)>& fn,
ThunkSequence* thunk_sequence);

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path);

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
10 changes: 10 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
@@ -410,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable(
platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(),
/*llvm_module_constants=*/nullptr,
/*emit_kernels=*/false);

absl::string_view cache_file_path =
hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
if (!cache_file_path.empty() &&
hlo_module->config()
.debug_options()
.xla_gpu_enable_llvm_module_compilation_parallelism()) {
TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
}

auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context);
TF_RETURN_IF_ERROR(
ir_emitter->EmitHloComputation(hlo_module->entry_computation()));
80 changes: 80 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -33,8 +34,12 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/compiler.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
@@ -44,8 +49,11 @@ limitations under the License.
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/service/xla_debug_info_manager.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
@@ -804,6 +812,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) {
EXPECT_EQ(CacheEntryCount(), 4);
}

TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) {
const std::string kHloAdd1 = R"(
add1 {
p = s32[] parameter(0)
c = s32[] constant(1)
ROOT a = s32[] add(p, c)
}
ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add1
})";

const std::string kHloAdd2 = R"(
add2 {
p = s32[] parameter(0)
c = s32[] constant(2)
ROOT a = s32[] add(p, c)
}
ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add2
})";

TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
se::PlatformManager::PlatformWithName("cuda"));
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec,
platform->ExecutorForDevice(0));

Compiler* compiler = backend().compiler();
AotCompilationOptions aot_options(compiler->PlatformId());
aot_options.set_executor(stream_exec);

auto test = [this, &compiler, &aot_options](absl::string_view hlo, int input,
int expected_result) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));
auto module_group = std::make_unique<HloModuleGroup>(std::move(module));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
compiler->CompileAheadOfTime(std::move(module_group), aot_options));

TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result,
aot_results[0]->SerializeAsString());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<AotCompilationResult> aot_result,
compiler->LoadAotCompilationResult(serialized_aot_result));

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Executable> executable,
aot_result->LoadExecutable(compiler, aot_options.executor()));

const xla::Literal literal_input =
xla::LiteralUtil::CreateR0<int32_t>(input);
const xla::Literal literal_expected_result =
xla::LiteralUtil::CreateR0<int32_t>(expected_result);

TF_ASSERT_OK_AND_ASSIGN(Literal result,
GetHloRunner().value()->ExecuteWithExecutable(
executable.get(), {&literal_input}));

EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result));
};

test(kHloAdd1, 1, 2);
test(kHloAdd2, 1, 3);
// The test used to fail on the second execution of the second module when it
// was already cached.
test(kHloAdd2, 1, 3);
}

class KernelCacheTestSingleThreaded : public KernelCacheTest {
public:
DebugOptions GetDebugOptionsForTest() override {