Skip to content

Commit bbe3d16

Browse files
sergachev authored and tensorflower-gardener committed
PR #15998: [GPU] Fix kernel cache for loaded executables.
Imported from GitHub PR openxla/xla#15998

LoadCache() must also be called in GpuThunkAotCompilationResult::LoadExecutable() so that the state of the name uniquer in the IR emitter context is properly initialized on this execution path. This makes the XLA kernel cache compatible with the JAX module cache (example test: `bazel test --test_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism --xla_gpu_kernel_cache_file=/dev/shm/xla.kernel.cache" tests/compilation_cache_test_gpu`).

Copybara import of the project:

-- b51fbcac5c7f172d06eb9de79770564f2e2c1250 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Fix kernel cache for loaded executables.

This makes the XLA kernel cache compatible with the JAX module cache.

Merging this change closes #15998

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15998 from openxla:fix_kernel_cache b51fbcac5c7f172d06eb9de79770564f2e2c1250

PiperOrigin-RevId: 663326327
1 parent 7f0ef5c commit bbe3d16

File tree

5 files changed

+113
-14
lines changed

5 files changed

+113
-14
lines changed

third_party/xla/xla/service/gpu/BUILD

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ cc_library(
12711271
":metrics",
12721272
":runtime_intrinsics",
12731273
"//xla:shape_util",
1274+
"//xla:status_macros",
12741275
"//xla:util",
12751276
"//xla:xla_data_proto_cc",
12761277
"//xla/hlo/ir:hlo",
@@ -1281,29 +1282,28 @@ cc_library(
12811282
"//xla/service:hlo_ordering",
12821283
"//xla/service:hlo_proto_cc",
12831284
"//xla/service:logical_buffer",
1284-
"//xla/service/gpu/runtime:conditional_thunk",
12851285
"//xla/service/gpu/runtime:sequential_thunk",
12861286
"//xla/service/gpu/runtime:thunk",
1287-
"//xla/service/gpu/runtime:while_thunk",
12881287
"//xla/stream_executor",
12891288
"//xla/stream_executor:device_description",
12901289
"//xla/stream_executor/rocm:rocm_platform_id",
12911290
"@com_google_absl//absl/container:flat_hash_map",
12921291
"@com_google_absl//absl/status",
12931292
"@com_google_absl//absl/status:statusor",
12941293
"@com_google_absl//absl/strings",
1294+
"@com_google_absl//absl/strings:str_format",
12951295
"@llvm-project//llvm:AsmParser",
12961296
"@llvm-project//llvm:TransformUtils",
12971297
"@llvm-project//llvm:ir_headers",
12981298
"@llvm-project//mlir:IR",
12991299
"@llvm-project//mlir:Pass",
13001300
"@llvm-project//mlir:Support",
1301-
"@local_tsl//tsl/platform:casts",
13021301
"@local_tsl//tsl/platform:env",
13031302
"@local_tsl//tsl/platform:errors",
13041303
"@local_tsl//tsl/platform:logging",
13051304
"@local_tsl//tsl/platform:path",
13061305
"@local_tsl//tsl/platform:statusor",
1306+
"@local_tsl//tsl/profiler/lib:scoped_annotation",
13071307
],
13081308
)
13091309

@@ -1611,18 +1611,25 @@ xla_test(
16111611
":metrics",
16121612
"//xla:autotune_results_proto_cc",
16131613
"//xla:error_spec",
1614+
"//xla:literal",
1615+
"//xla:literal_util",
16141616
"//xla:shape_util",
16151617
"//xla:xla_data_proto_cc",
16161618
"//xla/hlo/ir:hlo",
1619+
"//xla/hlo/ir:hlo_module_group",
1620+
"//xla/service:compiler",
16171621
"//xla/service:executable",
16181622
"//xla/service:hlo_module_config",
16191623
"//xla/service:pattern_matcher",
16201624
"//xla/service:pattern_matcher_gmock",
16211625
"//xla/service:xla_debug_info_manager",
16221626
"//xla/service/gpu/autotuning:autotuner_util",
16231627
"//xla/stream_executor:device_description",
1628+
"//xla/stream_executor:platform",
1629+
"//xla/stream_executor:platform_manager",
16241630
"//xla/tests:filecheck",
16251631
"//xla/tests:hlo_test_base",
1632+
"//xla/tests:literal_test_util",
16261633
"//xla/tests:xla_internal_test_main",
16271634
"//xla/tsl/lib/core:status_test_util",
16281635
"@com_google_absl//absl/log",

third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818
#include <stdlib.h>
1919

2020
#include <cstdint>
21-
#include <functional>
2221
#include <memory>
2322
#include <optional>
2423
#include <string>
@@ -28,6 +27,8 @@ limitations under the License.
2827

2928
#include "absl/status/status.h"
3029
#include "absl/strings/str_cat.h"
30+
#include "absl/strings/str_format.h"
31+
#include "absl/strings/string_view.h"
3132
#include "llvm/AsmParser/Parser.h"
3233
#include "llvm/IR/DiagnosticInfo.h"
3334
#include "llvm/IR/DiagnosticPrinter.h"
@@ -53,25 +54,22 @@ limitations under the License.
5354
#include "xla/service/gpu/ir_emitter_context.h"
5455
#include "xla/service/gpu/ir_emitter_unnested.h"
5556
#include "xla/service/gpu/metrics.h"
56-
#include "xla/service/gpu/runtime/conditional_thunk.h"
57-
#include "xla/service/gpu/runtime/sequential_thunk.h"
58-
#include "xla/service/gpu/runtime/thunk.h"
59-
#include "xla/service/gpu/runtime/while_thunk.h"
6057
#include "xla/service/hlo_dataflow_analysis.h"
6158
#include "xla/service/hlo_ordering.h"
6259
#include "xla/service/logical_buffer.h"
6360
#include "xla/shape.h"
61+
#include "xla/status_macros.h"
6462
#include "xla/stream_executor/device_description.h"
6563
#include "xla/stream_executor/platform.h"
6664
#include "xla/stream_executor/rocm/rocm_platform_id.h"
6765
#include "xla/util.h"
6866
#include "xla/xla_data.pb.h"
69-
#include "tsl/platform/casts.h"
7067
#include "tsl/platform/env.h"
7168
#include "tsl/platform/errors.h"
7269
#include "tsl/platform/logging.h"
7370
#include "tsl/platform/path.h"
7471
#include "tsl/platform/statusor.h"
72+
#include "tsl/profiler/lib/scoped_annotation.h"
7573

7674
namespace xla::gpu {
7775

@@ -102,8 +100,10 @@ void RemoveUnusedAndUninitializedGlobals(
102100
}
103101
}
104102

105-
static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
106-
absl::string_view cache_file_path) {
103+
} // namespace
104+
105+
absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
106+
absl::string_view cache_file_path) {
107107
std::string resolved_path;
108108
if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) {
109109
return FailedPrecondition("File path can not be resolved: %s",
@@ -114,7 +114,7 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
114114
TF_RETURN_IF_ERROR(
115115
tsl::ReadFileToString(tsl::Env::Default(), resolved_path, &serialized));
116116
CompilationCacheProto proto;
117-
if (!proto.ParseFromString(std::string(serialized))) {
117+
if (!proto.ParseFromString(serialized)) {
118118
return Internal("Failed to parse serialized CompilationCacheProto.");
119119
}
120120
// Register all cached kernel names with the name uniquer to avoid
@@ -131,8 +131,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
131131
return absl::OkStatus();
132132
}
133133

134-
} // namespace
135-
136134
absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
137135
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
138136
const std::string& target_triple, const std::string& data_layout,

third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "xla/service/gpu/executable.pb.h"
3232
#include "xla/service/gpu/execution_stream_assignment.h"
3333
#include "xla/service/gpu/gpu_executable.h"
34+
#include "xla/service/gpu/ir_emitter_context.h"
3435
#include "xla/service/gpu/runtime/sequential_thunk.h"
3536
#include "xla/service/gpu/runtime/thunk.h"
3637
#include "xla/service/hlo.pb.h"
@@ -66,6 +67,9 @@ struct CompileModuleResults {
6667
void ForAllThunks(const std::function<void(Thunk*)>& fn,
6768
ThunkSequence* thunk_sequence);
6869

70+
// Loads the kernel compilation cache stored at `cache_file_path` and registers
// the cached kernel names with the name uniquer of `ir_emitter_context`.
//
// Returns a FailedPrecondition error if the path cannot be resolved, forwards
// the error from reading the file, and returns an Internal error if the file
// contents do not parse as a CompilationCacheProto.
absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
71+
absl::string_view cache_file_path);
72+
6973
absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
7074
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
7175
const std::string& target_triple, const std::string& data_layout,

third_party/xla/xla/service/gpu/gpu_compiler.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable(
410410
platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(),
411411
/*llvm_module_constants=*/nullptr,
412412
/*emit_kernels=*/false);
413+
414+
absl::string_view cache_file_path =
415+
hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
416+
if (!cache_file_path.empty() &&
417+
hlo_module->config()
418+
.debug_options()
419+
.xla_gpu_enable_llvm_module_compilation_parallelism()) {
420+
TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
421+
}
422+
413423
auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context);
414424
TF_RETURN_IF_ERROR(
415425
ir_emitter->EmitHloComputation(hlo_module->entry_computation()));

third_party/xla/xla/service/gpu/gpu_compiler_test.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include <memory>
2121
#include <string>
2222
#include <utility>
23+
#include <vector>
2324

2425
#include <gmock/gmock.h>
2526
#include <gtest/gtest.h>
@@ -33,8 +34,12 @@ limitations under the License.
3334
#include "xla/hlo/ir/hlo_computation.h"
3435
#include "xla/hlo/ir/hlo_instruction.h"
3536
#include "xla/hlo/ir/hlo_module.h"
37+
#include "xla/hlo/ir/hlo_module_group.h"
3638
#include "xla/hlo/ir/hlo_opcode.h"
39+
#include "xla/literal.h"
40+
#include "xla/literal_util.h"
3741
#include "xla/primitive_util.h"
42+
#include "xla/service/compiler.h"
3843
#include "xla/service/executable.h"
3944
#include "xla/service/gpu/autotuning/autotuner_util.h"
4045
#include "xla/service/gpu/gpu_hlo_schedule.h"
@@ -44,8 +49,11 @@ limitations under the License.
4449
#include "xla/service/pattern_matcher_gmock.h"
4550
#include "xla/service/xla_debug_info_manager.h"
4651
#include "xla/stream_executor/device_description.h"
52+
#include "xla/stream_executor/platform.h"
53+
#include "xla/stream_executor/platform_manager.h"
4754
#include "xla/tests/filecheck.h"
4855
#include "xla/tests/hlo_test_base.h"
56+
#include "xla/tests/literal_test_util.h"
4957
#include "xla/tsl/lib/core/status_test_util.h"
5058
#include "xla/xla_data.pb.h"
5159
#include "tsl/platform/casts.h"
@@ -804,6 +812,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) {
804812
EXPECT_EQ(CacheEntryCount(), 4);
805813
}
806814

815+
// Checks that executables obtained through the AOT path (compile ahead of
// time -> serialize -> load result -> LoadExecutable) produce correct results
// when the same module is compiled and loaded more than once.
TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) {
816+
// Module 1: a loop fusion that adds the constant 1 to its s32 scalar input.
const std::string kHloAdd1 = R"(
817+
add1 {
818+
p = s32[] parameter(0)
819+
c = s32[] constant(1)
820+
ROOT a = s32[] add(p, c)
821+
}
822+
823+
ENTRY e {
824+
p = s32[] parameter(0)
825+
ROOT r = s32[] fusion(p), kind=kLoop, calls=add1
826+
})";
827+
828+
// Module 2: same structure as module 1, but adds the constant 2.
const std::string kHloAdd2 = R"(
829+
add2 {
830+
p = s32[] parameter(0)
831+
c = s32[] constant(2)
832+
ROOT a = s32[] add(p, c)
833+
}
834+
835+
ENTRY e {
836+
p = s32[] parameter(0)
837+
ROOT r = s32[] fusion(p), kind=kLoop, calls=add2
838+
})";
839+
840+
// NOTE(review): the platform name is hard-coded to "cuda" -- confirm whether
// this test is expected to run on other GPU platforms as well.
TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
841+
se::PlatformManager::PlatformWithName("cuda"));
842+
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec,
843+
platform->ExecutorForDevice(0));
844+
845+
// AOT options are shared by all invocations of the helper below and carry
// the executor of device 0.
Compiler* compiler = backend().compiler();
846+
AotCompilationOptions aot_options(compiler->PlatformId());
847+
aot_options.set_executor(stream_exec);
848+
849+
// Helper: AOT-compiles `hlo`, round-trips the result through serialization,
// loads an executable from it, runs it on the scalar `input` and compares
// the output with `expected_result`.
auto test = [this, &compiler, &aot_options](absl::string_view hlo, int input,
850+
int expected_result) {
851+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
852+
ParseAndReturnVerifiedModule(hlo));
853+
auto module_group = std::make_unique<HloModuleGroup>(std::move(module));
854+
TF_ASSERT_OK_AND_ASSIGN(
855+
std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
856+
compiler->CompileAheadOfTime(std::move(module_group), aot_options));
857+
858+
// NOTE(review): indexes aot_results[0] without checking the vector size;
// presumably one result per module in the group -- confirm.
TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result,
859+
aot_results[0]->SerializeAsString());
860+
TF_ASSERT_OK_AND_ASSIGN(
861+
std::unique_ptr<AotCompilationResult> aot_result,
862+
compiler->LoadAotCompilationResult(serialized_aot_result));
863+
864+
// The executable is created from the deserialized result rather than
// taken directly from the compilation above.
TF_ASSERT_OK_AND_ASSIGN(
865+
std::unique_ptr<Executable> executable,
866+
aot_result->LoadExecutable(compiler, aot_options.executor()));
867+
868+
const xla::Literal literal_input =
869+
xla::LiteralUtil::CreateR0<int32_t>(input);
870+
const xla::Literal literal_expected_result =
871+
xla::LiteralUtil::CreateR0<int32_t>(expected_result);
872+
873+
TF_ASSERT_OK_AND_ASSIGN(Literal result,
874+
GetHloRunner().value()->ExecuteWithExecutable(
875+
executable.get(), {&literal_input}));
876+
877+
EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result));
878+
};
879+
880+
// 1 + 1 == 2 for the first module, 1 + 2 == 3 for the second.
test(kHloAdd1, 1, 2);
881+
test(kHloAdd2, 1, 3);
882+
// The test used to fail on the second execution of the second module when it
883+
// was already cached.
884+
test(kHloAdd2, 1, 3);
885+
}
886+
807887
class KernelCacheTestSingleThreaded : public KernelCacheTest {
808888
public:
809889
DebugOptions GetDebugOptionsForTest() override {

0 commit comments

Comments
 (0)