@@ -83,6 +83,7 @@ int main(int argc, char** argv) {
8383 tensorflow::serving::main::Server::Options options;
8484 bool display_version = false ;
8585 bool xla_cpu_compilation_enabled = false ;
86+ bool xla_gpu_compilation_enabled = false ;
8687 std::vector<tensorflow::Flag> flag_list = {
8788 tensorflow::Flag (" port" , &options.grpc_port ,
8889 " TCP port to listen on for gRPC/HTTP API. Disabled if "
@@ -290,6 +291,10 @@ int main(int argc, char** argv) {
290291 " Enable XLA:CPU JIT (default is disabled). With XLA:CPU JIT "
291292 " disabled, models utilizing this feature will return bad Status "
292293 " on first compilation request." ),
294+ tensorflow::Flag (
295+ " xla_gpu_compilation_enabled" , &xla_gpu_compilation_enabled,
296+ " EXPERIMENTAL; CAN BE REMOVED ANYTIME! "
297+ " Enable both XLA:CPU JIT and XLA:GPU JIT (default is disabled)." ),
293298 tensorflow::Flag (" enable_profiler" , &options.enable_profiler ,
294299 " Enable profiler service." ),
295300 tensorflow::Flag (" thread_pool_factory_config_file" ,
@@ -325,7 +330,7 @@ int main(int argc, char** argv) {
325330 std::cout << " unknown argument: " << argv[1 ] << " \n " << usage;
326331 }
327332
328- if (!xla_cpu_compilation_enabled) {
333+ if (!xla_cpu_compilation_enabled && !xla_gpu_compilation_enabled ) {
329334 tensorflow::DisableXlaCompilation ();
330335 }
331336
0 commit comments