@@ -398,22 +398,6 @@ def get_gpu_info():
398398 return _get_amd_gpu_info () if shutil .which ("rocm-smi" ) \
399399 else _get_nvidia_gpu_info ()
400400
401-
402- def _install_tpu_tool ():
403- """Installs the ctpu tool to managing cloud TPUs.
404-
405- Follows the instructions here:
406- https://github.com/tensorflow/tpu/tree/master/tools/ctpu
407- """
408- if not os .path .exists ('ctpu' ):
409- logging .info ('Installing TPU tool' )
410- commands = [
411- 'wget https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu' ,
412- 'chmod a+x ctpu' ,
413- ]
414- run_commands (commands )
415-
416-
417401def setup_tpu (parameters ):
418402 """Sets up a TPU with a given set of parameters.
419403
@@ -424,26 +408,23 @@ def setup_tpu(parameters):
424408 True if an error occurs during setup.
425409 """
426410 try :
427- _install_tpu_tool ()
428-
411+ base_cmd = 'gcloud compute tpus execution-groups create'
429412 args = [
413+ '--tpu-only' ,
430414 '--name={}' .format (parameters .get ('name' )),
431415 '--project={}' .format (parameters .get ('project' )),
432416 '--zone={}' .format (parameters .get ('zone' )),
433- '--tpu-size ={}' .format (parameters .get ('size' )),
417+ '--accelerator-type ={}' .format (parameters .get ('size' )),
434418 '--tf-version={}' .format (parameters .get ('version' )),
435- '--tpu-only' ,
436- '-noconf' ,
437419 ]
438- command = './ctpu up {}' .format (' ' .join (args ))
420+ command = '{} {}' .format (base_cmd , ' ' .join (args ))
439421 logging .info ('Setting up TPU: %s' , command )
440422 exit_code , output = run_command (command )
441423 if exit_code != 0 :
442424 logging .error ('Error in setup with output: %s' , output )
443425 return exit_code != 0
444426 except Exception :
445427 logging .error ('Unable to setup TPU' )
446- run_command ('rm -f ctpu' )
447428 sys .exit (1 )
448429
449430
@@ -456,16 +437,15 @@ def cleanup_tpu(parameters):
456437 Returns:
457438 True if an error occurs during cleanup.
458439 """
459- _install_tpu_tool ()
440+
441+ base_cmd = 'gcloud compute tpus execution-groups delete'
460442
461443 args = [
462- '--name= {}' .format (parameters .get ('name' )),
444+ '{}' .format (parameters .get ('name' )),
463445 '--project={}' .format (parameters .get ('project' )),
464446 '--zone={}' .format (parameters .get ('zone' )),
465- '--tpu-only' ,
466- '-noconf' ,
467447 ]
468- command = './ctpu delete {}' .format (' ' .join (args ))
448+ command = '{} {}' .format (base_cmd , ' ' .join (args ))
469449 logging .info ('Cleaning up TPU: %s' , command )
470450 exit_code , output = run_command (command )
471451 if exit_code != 0 :
0 commit comments