Skip to content

Commit

Permalink
PR #15998: [GPU] Fix kernel cache for loaded executables.
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#15998

LoadCache() also has to be called in GpuThunkAotCompilationResult::LoadExecutable() to properly initialize the state of the name uniquer in the IR emitter context on this path to execution.

This makes the XLA kernel cache compatible with the JAX module cache (example test: `bazel test --test_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism --xla_gpu_kernel_cache_file=/dev/shm/xla.kernel.cache" tests/compilation_cache_test_gpu`).
Copybara import of the project:

--
b51fbcac5c7f172d06eb9de79770564f2e2c1250 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Fix kernel cache for loaded executables.

This makes the XLA kernel cache compatible with JAX module cache.

Merging this change closes #15998

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15998 from openxla:fix_kernel_cache b51fbcac5c7f172d06eb9de79770564f2e2c1250
PiperOrigin-RevId: 663326327
  • Loading branch information
sergachev authored and tensorflower-gardener committed Aug 16, 2024
1 parent 7f0ef5c commit bbe3d16
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 14 deletions.
13 changes: 10 additions & 3 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,7 @@ cc_library(
":metrics",
":runtime_intrinsics",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand All @@ -1281,29 +1282,28 @@ cc_library(
"//xla/service:hlo_ordering",
"//xla/service:hlo_proto_cc",
"//xla/service:logical_buffer",
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/runtime:while_thunk",
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor/rocm:rocm_platform_id",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//llvm:ir_headers",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:scoped_annotation",
],
)

Expand Down Expand Up @@ -1611,18 +1611,25 @@ xla_test(
":metrics",
"//xla:autotune_results_proto_cc",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/service:compiler",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
Expand Down
20 changes: 9 additions & 11 deletions third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#include <stdlib.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
Expand All @@ -28,6 +27,8 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
Expand All @@ -53,25 +54,22 @@ limitations under the License.
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/ir_emitter_unnested.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/runtime/conditional_thunk.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/runtime/while_thunk.h"
#include "xla/service/hlo_dataflow_analysis.h"
#include "xla/service/hlo_ordering.h"
#include "xla/service/logical_buffer.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/path.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/scoped_annotation.h"

namespace xla::gpu {

Expand Down Expand Up @@ -102,8 +100,10 @@ void RemoveUnusedAndUninitializedGlobals(
}
}

static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
} // namespace

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
std::string resolved_path;
if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) {
return FailedPrecondition("File path can not be resolved: %s",
Expand All @@ -114,7 +114,7 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
TF_RETURN_IF_ERROR(
tsl::ReadFileToString(tsl::Env::Default(), resolved_path, &serialized));
CompilationCacheProto proto;
if (!proto.ParseFromString(std::string(serialized))) {
if (!proto.ParseFromString(serialized)) {
return Internal("Failed to parse serialized CompilationCacheProto.");
}
// Register all cached kernel names with the name uniquer to avoid
Expand All @@ -131,8 +131,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
return absl::OkStatus();
}

} // namespace

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "xla/service/gpu/executable.pb.h"
#include "xla/service/gpu/execution_stream_assignment.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/hlo.pb.h"
Expand Down Expand Up @@ -66,6 +67,9 @@ struct CompileModuleResults {
void ForAllThunks(const std::function<void(Thunk*)>& fn,
ThunkSequence* thunk_sequence);

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path);

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
Expand Down
10 changes: 10 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable(
platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(),
/*llvm_module_constants=*/nullptr,
/*emit_kernels=*/false);

absl::string_view cache_file_path =
hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
if (!cache_file_path.empty() &&
hlo_module->config()
.debug_options()
.xla_gpu_enable_llvm_module_compilation_parallelism()) {
TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
}

auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context);
TF_RETURN_IF_ERROR(
ir_emitter->EmitHloComputation(hlo_module->entry_computation()));
Expand Down
80 changes: 80 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand All @@ -33,8 +34,12 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/primitive_util.h"
#include "xla/service/compiler.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
Expand All @@ -44,8 +49,11 @@ limitations under the License.
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/service/xla_debug_info_manager.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
Expand Down Expand Up @@ -804,6 +812,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) {
EXPECT_EQ(CacheEntryCount(), 4);
}

// Verifies that the kernel cache also works for executables that reach
// execution through the AOT path (serialize -> LoadAotCompilationResult ->
// LoadExecutable), not only for directly compiled ones. Two modules with
// distinct fused kernels are compiled ahead of time, round-tripped through
// serialization, loaded and executed. The second module is executed twice:
// the regression this guards against only showed up on re-execution, once
// its kernel was already present in the cache.
TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) {
  const std::string kHloAdd1 = R"(
add1 {
p = s32[] parameter(0)
c = s32[] constant(1)
ROOT a = s32[] add(p, c)
}
ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add1
})";

  const std::string kHloAdd2 = R"(
add2 {
p = s32[] parameter(0)
c = s32[] constant(2)
ROOT a = s32[] add(p, c)
}
ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add2
})";

  // AOT compilation requires a concrete executor to target.
  TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                          se::PlatformManager::PlatformWithName("cuda"));
  TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec,
                          platform->ExecutorForDevice(0));

  Compiler* compiler = backend().compiler();
  AotCompilationOptions aot_options(compiler->PlatformId());
  aot_options.set_executor(stream_exec);

  // Compiles `hlo` ahead of time, round-trips the result through its
  // serialized form, loads it as an executable, runs it on `argument`, and
  // checks that the output equals `want`.
  auto compile_reload_and_run = [this, &compiler, &aot_options](
                                    absl::string_view hlo, int argument,
                                    int want) {
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                            ParseAndReturnVerifiedModule(hlo));
    auto group = std::make_unique<HloModuleGroup>(std::move(module));
    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
        compiler->CompileAheadOfTime(std::move(group), aot_options));

    // Serialize and reload so that execution goes through LoadExecutable(),
    // which is the code path under test.
    TF_ASSERT_OK_AND_ASSIGN(std::string serialized,
                            aot_results[0]->SerializeAsString());
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<AotCompilationResult> reloaded,
        compiler->LoadAotCompilationResult(serialized));
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<Executable> executable,
        reloaded->LoadExecutable(compiler, aot_options.executor()));

    const xla::Literal input = xla::LiteralUtil::CreateR0<int32_t>(argument);
    const xla::Literal expected = xla::LiteralUtil::CreateR0<int32_t>(want);

    TF_ASSERT_OK_AND_ASSIGN(Literal result,
                            GetHloRunner().value()->ExecuteWithExecutable(
                                executable.get(), {&input}));
    EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
  };

  compile_reload_and_run(kHloAdd1, 1, 2);
  compile_reload_and_run(kHloAdd2, 1, 3);
  // Running the second module again — its kernel is now already in the
  // cache — used to fail before the fix this test accompanies.
  compile_reload_and_run(kHloAdd2, 1, 3);
}

class KernelCacheTestSingleThreaded : public KernelCacheTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down

0 comments on commit bbe3d16

Please sign in to comment.