-
Notifications
You must be signed in to change notification settings - Fork 23
/
podrun
executable file
·44 lines (32 loc) · 1.47 KB
/
podrun
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/python3
import argparse
import fabric
import os
# https://stackoverflow.com/a/13786877
def shell_escape(arg):
    r"""Quote *arg* so a POSIX shell treats it as one literal word.

    The text is wrapped in single quotes. A single quote cannot appear inside
    a single-quoted string, so each embedded ``'`` is emitted as ``'\''``:
    close the quote, add a backslash-escaped quote, reopen the quote.

    Example: ``it's`` becomes ``'it'\''s'``.
    """
    quote = "'"
    escaped_quote = quote + "\\" + quote + quote
    pieces = arg.split(quote)
    return quote + escaped_quote.join(pieces) + quote
def get_ips(path='~/podips.txt'):
    """Return the pod host addresses listed in *path*, one per line.

    Args:
        path: File to read; a leading ``~`` is expanded. Defaults to
            ``~/podips.txt``.

    Returns:
        A list of host strings in file order. Surrounding whitespace is
        stripped and blank lines are skipped. A trailing empty line or stray
        spaces therefore never produce an empty or malformed host name.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    with open(os.path.expanduser(path), 'r', encoding='utf-8') as f:
        # Strip all surrounding whitespace, not just the newline: spaces or
        # tabs would otherwise become part of the host name.
        stripped = (line.strip() for line in f)
        return [host for host in stripped if host]
def run_command(hosts, command):
    """Run the shell string *command* on every host in *hosts*.

    The hosts are handed to Fabric's ``ThreadingGroup``. The group is
    closed when the ``with`` block exits. Nothing is returned, and errors
    raised by the group run are not caught here, so they propagate to the
    caller.
    """
    pod = fabric.ThreadingGroup(*hosts)
    with pod:
        pod.run(command)
def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='A helper script to execute commands on multiple hosts of a TPU pod.')
    parser.add_argument('-i', '--include-local', action='store_true', help='include local host (127.0.0.1) in the host list')
    parser.add_argument('-c', '--clean-up', action='store_true', help='clean up temporary files generated by previous TPU processes before executing the command')
    parser.add_argument('-w', '--cwd', action='store_true', help='run the command in the current working directory, assuming the directory exists on all hosts')
    # REMAINDER (rather than ONE_OR_MORE) gives every word from the first
    # positional onward to the remote command. The command's own flags are
    # then not parsed as this script's -i/-c/-w options. With ONE_OR_MORE,
    # `podrun python3 x.py -c` silently enabled the sudo cleanup and dropped
    # -c from the command.
    parser.add_argument('command', nargs=argparse.REMAINDER, help='command (and its arguments) to run on every host')
    return parser


def _build_command(words, clean_up, cwd):
    """Return the shell command line to execute on each host.

    Args:
        words: The command and its arguments. Each word is quoted with
            ``shell_escape`` so the remote shell receives it verbatim.
        clean_up: If true, first run
            ``sudo rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs``.
        cwd: Directory to ``cd`` into before running the command, or None
            to skip the ``cd``.
    """
    command = ' '.join(shell_escape(word) for word in words)
    if cwd is not None:
        # Quote the path, since it may contain spaces or shell metacharacters.
        # Use '&&' so that if the directory is missing on a host, the command
        # fails there instead of running in some other directory.
        command = f'cd {shell_escape(cwd)} && {command}'
    if clean_up:
        # ';' keeps the original semantics: the command runs even if cleanup
        # fails. The paths are absolute, so doing this before the cd is fine.
        command = f'sudo rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs; {command}'
    return command


def main():
    """Entry point: parse the CLI and run the command on every pod host.

    Hosts come from ``get_ips()``, plus 127.0.0.1 when ``-i`` is given.
    Usage errors (no command, no hosts) are reported via ``parser.error``,
    which exits the process.
    """
    parser = _build_parser()
    args = parser.parse_args()

    words = args.command
    # Depending on the Python version, argparse may leave the '--' separator
    # at the head of a REMAINDER list. Drop it so the `podrun -- cmd ...`
    # form keeps working.
    if words[:1] == ['--']:
        words = words[1:]
    if not words:
        parser.error('no command given')

    hosts = get_ips()
    if args.include_local:
        hosts.append('127.0.0.1')
    if not hosts:
        # Fail loudly rather than silently doing nothing with no hosts.
        parser.error('no hosts to run on: ~/podips.txt lists none (use -i to include the local host)')

    cwd = os.getcwd() if args.cwd else None
    run_command(hosts, _build_command(words, args.clean_up, cwd))
# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()