Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Fix kernel cache for loaded executables. #15998

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1280,10 +1280,6 @@ cc_library(
"//xla/service:hlo_ordering",
"//xla/service:hlo_proto_cc",
"//xla/service:logical_buffer",
"//xla/service/gpu/runtime:conditional_thunk",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/runtime:while_thunk",
"//xla/stream_executor",
"//xla/stream_executor:device_description",
"//xla/stream_executor/rocm:rocm_platform_id",
Expand Down Expand Up @@ -1618,6 +1614,7 @@ xla_test(
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform_manager",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
Expand Down
13 changes: 4 additions & 9 deletions xla/service/gpu/compile_module_to_llvm_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,8 @@ limitations under the License.
#include "xla/service/gpu/gpu_constants.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/gpu_memory_space_assignment.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/ir_emitter_unnested.h"
#include "xla/service/gpu/metrics.h"
#include "xla/service/gpu/runtime/conditional_thunk.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/runtime/while_thunk.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should those cleanups be in this commit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's tempting to include such trivial small cleanups while changing the files for other reasons (also because for us, external developers, submitting changes is a slower process). I'll try to separate them in the future.

#include "xla/service/hlo_dataflow_analysis.h"
#include "xla/service/hlo_ordering.h"
#include "xla/service/logical_buffer.h"
Expand Down Expand Up @@ -102,8 +97,10 @@ void RemoveUnusedAndUninitializedGlobals(
}
}

static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
} // namespace

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path) {
std::string resolved_path;
if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) {
return FailedPrecondition("File path can not be resolved: %s",
Expand Down Expand Up @@ -131,8 +128,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
return absl::OkStatus();
}

} // namespace

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/compile_module_to_llvm_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "xla/service/gpu/executable.pb.h"
#include "xla/service/gpu/execution_stream_assignment.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/hlo.pb.h"
Expand Down Expand Up @@ -66,6 +67,9 @@ struct CompileModuleResults {
void ForAllThunks(const std::function<void(Thunk*)>& fn,
ThunkSequence* thunk_sequence);

absl::Status LoadCache(IrEmitterContext& ir_emitter_context,
absl::string_view cache_file_path);

absl::StatusOr<CompileModuleResults> CompileModuleToLlvmIr(
HloModule* hlo_module, llvm::LLVMContext* llvm_context,
const std::string& target_triple, const std::string& data_layout,
Expand Down
10 changes: 10 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,16 @@ GpuThunkAotCompilationResult::LoadExecutable(
platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(),
/*llvm_module_constants=*/nullptr,
/*emit_kernels=*/false);

absl::string_view cache_file_path =
hlo_module->config().debug_options().xla_gpu_kernel_cache_file();
if (!cache_file_path.empty() &&
hlo_module->config()
.debug_options()
.xla_gpu_enable_llvm_module_compilation_parallelism()) {
TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path));
}

auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context);
TF_RETURN_IF_ERROR(
ir_emitter->EmitHloComputation(hlo_module->entry_computation()));
Expand Down
73 changes: 73 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/service/xla_debug_info_manager.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
Expand Down Expand Up @@ -804,6 +805,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) {
EXPECT_EQ(CacheEntryCount(), 4);
}

TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) {
  // A module whose fusion adds the constant 1 to its s32 input.
  const std::string kHloAdd1 = R"(
add1 {
p = s32[] parameter(0)
c = s32[] constant(1)
ROOT a = s32[] add(p, c)
}

ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add1
})";

  // A structurally identical module that adds 2 instead, so its kernels can
  // interact with cache entries produced for the first module.
  const std::string kHloAdd2 = R"(
add2 {
p = s32[] parameter(0)
c = s32[] constant(2)
ROOT a = s32[] add(p, c)
}

ENTRY e {
p = s32[] parameter(0)
ROOT r = s32[] fusion(p), kind=kLoop, calls=add2
})";

  TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                          se::PlatformManager::PlatformWithName("cuda"));
  TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec,
                          platform->ExecutorForDevice(0));

  Compiler* compiler = backend().compiler();
  AotCompilationOptions aot_options(compiler->PlatformId());
  aot_options.set_executor(stream_exec);

  // AOT-compiles `hlo`, round-trips the result through serialization (to
  // emulate loading a stored executable), runs it on `input`, and checks the
  // produced scalar against `expected`.
  auto compile_load_and_check = [this, &compiler, &aot_options](
                                    absl::string_view hlo, int input,
                                    int expected) {
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                            ParseAndReturnVerifiedModule(hlo));
    auto group = std::make_unique<HloModuleGroup>(std::move(module));
    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<std::unique_ptr<AotCompilationResult>> compiled,
        compiler->CompileAheadOfTime(std::move(group), aot_options));

    // Serialize and immediately deserialize so the executable below is a
    // *loaded* one, which is the code path this test covers.
    TF_ASSERT_OK_AND_ASSIGN(std::string serialized,
                            compiled[0]->SerializeAsString());
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<AotCompilationResult> deserialized,
        compiler->LoadAotCompilationResult(serialized));

    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<Executable> executable,
        deserialized->LoadExecutable(compiler, aot_options.executor()));

    const xla::Literal arg = xla::LiteralUtil::CreateR0<int32_t>(input);
    const xla::Literal want = xla::LiteralUtil::CreateR0<int32_t>(expected);

    TF_ASSERT_OK_AND_ASSIGN(Literal got,
                            GetHloRunner().value()->ExecuteWithExecutable(
                                executable.get(), {&arg}));

    EXPECT_TRUE(LiteralTestUtil::Equal(got, want));
  };

  compile_load_and_check(kHloAdd1, /*input=*/1, /*expected=*/2);
  compile_load_and_check(kHloAdd2, /*input=*/1, /*expected=*/3);
  // The test used to fail on the second execution of the second module when it
  // was already cached.
  compile_load_and_check(kHloAdd2, /*input=*/1, /*expected=*/3);
}

class KernelCacheTestSingleThreaded : public KernelCacheTest {
public:
DebugOptions GetDebugOptionsForTest() override {
Expand Down
Loading