Skip to content

Commit e5e795f

Browse files
Added capability to use XLA on a GPU.
PiperOrigin-RevId: 682116663
1 parent 22e181d commit e5e795f

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tensorflow_serving/model_servers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ cc_library(
537537
"@com_github_grpc_grpc//:grpc++_reflection",
538538
"@org_tensorflow//tensorflow/c:c_api",
539539
"@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit",
540+
"@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit",
540541
"@org_tensorflow//tensorflow/core:lib",
541542
"@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system",
542543
] + if_google(

tensorflow_serving/model_servers/main.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ int main(int argc, char** argv) {
8383
tensorflow::serving::main::Server::Options options;
8484
bool display_version = false;
8585
bool xla_cpu_compilation_enabled = false;
86+
bool xla_gpu_compilation_enabled = false;
8687
std::vector<tensorflow::Flag> flag_list = {
8788
tensorflow::Flag("port", &options.grpc_port,
8889
"TCP port to listen on for gRPC/HTTP API. Disabled if "
@@ -290,6 +291,10 @@ int main(int argc, char** argv) {
290291
"Enable XLA:CPU JIT (default is disabled). With XLA:CPU JIT "
291292
"disabled, models utilizing this feature will return bad Status "
292293
"on first compilation request."),
294+
tensorflow::Flag(
295+
"xla_gpu_compilation_enabled", &xla_gpu_compilation_enabled,
296+
"EXPERIMENTAL; CAN BE REMOVED ANYTIME! "
297+
"Enable both XLA:CPU JIT and XLA:GPU JIT (default is disabled)."),
293298
tensorflow::Flag("enable_profiler", &options.enable_profiler,
294299
"Enable profiler service."),
295300
tensorflow::Flag("thread_pool_factory_config_file",
@@ -325,7 +330,7 @@ int main(int argc, char** argv) {
325330
std::cout << "unknown argument: " << argv[1] << "\n" << usage;
326331
}
327332

328-
if (!xla_cpu_compilation_enabled) {
333+
if (!xla_cpu_compilation_enabled && !xla_gpu_compilation_enabled) {
329334
tensorflow::DisableXlaCompilation();
330335
}
331336

0 commit comments

Comments
 (0)