forked from pytorch/builder
-
Notifications
You must be signed in to change notification settings - Fork 1
/
build_aarch64_wheel.py
executable file
·598 lines (492 loc) · 25.2 KB
/
build_aarch64_wheel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
#!/usr/bin/env python3
import boto3
import os
import subprocess
import sys
import time
from typing import List, Optional, Tuple, Union
# AMI images for us-east-1, change the following based on your ~/.aws/config
# Maps the OS name accepted by the --os command-line option to the AMI id to launch.
os_amis = {
'ubuntu18_04': "ami-0f2b111fdc1647918", # login_name: ubuntu
'ubuntu20_04': "ami-0ea142bd244023692", # login_name: ubuntu
'redhat8': "ami-0698b90665a2ddcf1", # login_name: ec2-user
}
# Default image for start_instance(); start_build() also compares host.ami
# against it when deciding whether to install and switch to gcc-8.
ubuntu18_04_ami = os_amis['ubuntu18_04']
def compute_keyfile_path(key_name: Optional[str] = None) -> Tuple[str, str]:
    """Resolve the SSH private key file and the AWS key-pair name.

    The key name is taken from ``key_name`` or, when that is None, from the
    ``AWS_KEY_NAME`` environment variable.  The key file path is taken from
    ``SSH_KEY_PATH`` and otherwise defaults to ``~/.ssh/<key_name>.pem``.

    Returns:
        ``(keyfile_path, key_name)``.  When no key name can be determined the
        name is ``""`` and the path is ``SSH_KEY_PATH`` or ``""``.
    """
    resolved_name = key_name if key_name is not None else os.getenv("AWS_KEY_NAME")
    if resolved_name is None:
        return os.getenv("SSH_KEY_PATH", ""), ""

    pem_path = os.path.join(os.path.expanduser("~"), ".ssh", f"{resolved_name}.pem")
    return os.getenv("SSH_KEY_PATH", pem_path), resolved_name
# Module-level boto3 EC2 resource shared by the ec2_* helpers and
# start_instance(); it is created when the module is loaded.
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
    """Return the EC2 instance collection matching one name/value filter."""
    single_filter = {'Name': filter_name, 'Values': [filter_value]}
    return ec2.instances.filter(Filters=[single_filter])
def ec2_instances_of_type(instance_type='t4g.2xlarge'):
    """Return the collection of EC2 instances whose type is ``instance_type``."""
    return ec2_get_instances(filter_name='instance-type', filter_value=instance_type)
def ec2_instances_by_id(instance_id):
    """Return the EC2 instance with the given id, or None if none matches."""
    matches = list(ec2_get_instances('instance-id', instance_id))
    if not matches:
        return None
    return matches[0]
def start_instance(key_name, ami=ubuntu18_04_ami, instance_type='t4g.2xlarge'):
    """Launch one EC2 instance and block until it is running.

    Args:
        key_name: name of the EC2 key pair to attach to the instance.
        ami: image id to boot.
        instance_type: EC2 instance type to request.

    Returns:
        The instance object, looked up again by id after it entered the
        running state, as found by ec2_instances_by_id().
    """
    launch_params = dict(
        ImageId=ami,
        InstanceType=instance_type,
        SecurityGroups=['ssh-allworld'],
        KeyName=key_name,
        MinCount=1,
        MaxCount=1,
    )
    created = ec2.create_instances(**launch_params)
    inst = created[0]
    print(f'Create instance {inst.id}')
    inst.wait_until_running()

    # Look the instance up again rather than reusing the object returned by
    # create_instances().
    running_inst = ec2_instances_by_id(inst.id)
    print(f'Instance started at {running_inst.public_dns_name}')
    return running_inst
class RemoteHost:
    """A machine reachable over SSH, optionally with a Docker build container.

    Until start_docker() is called, run_cmd/check_output/upload_file/
    download_file operate directly on the host via ssh/scp.  Afterwards they
    are routed into the container identified by ``container_id``.
    """
    addr: str
    keyfile_path: str
    login_name: str
    # Id of the container started by start_docker(); None means "no docker".
    container_id: Optional[str] = None
    # AMI the host was booted from; set by the caller after construction.
    ami: Optional[str] = None

    def __init__(self, addr: str, keyfile_path: str, login_name: str = 'ubuntu'):
        """Remember the address, private key file and login used for ssh/scp."""
        self.addr = addr
        self.keyfile_path = keyfile_path
        self.login_name = login_name

    def _gen_ssh_prefix(self) -> List[str]:
        """Return the ssh argv prefix; the remote command is appended after ``--``."""
        return ["ssh", "-o", "StrictHostKeyChecking=no", "-i", self.keyfile_path,
                f"{self.login_name}@{self.addr}", "--"]

    @staticmethod
    def _split_cmd(args: Union[str, List[str]]) -> List[str]:
        """Split a command string on whitespace; lists are returned unchanged."""
        return args.split() if isinstance(args, str) else args

    def _gen_docker_exec(self, args: Union[str, List[str]]) -> Tuple[List[str], bytes]:
        """Return the argv that starts bash in the container and the script to feed it.

        The script sources ``.bashrc`` first and is meant to be written to the
        stdin of the ``docker exec -i ... bash`` process.
        """
        assert self.container_id is not None
        docker_cmd = self._gen_ssh_prefix() + ['docker', 'exec', '-i', self.container_id, 'bash']
        script = " ".join(["source .bashrc;"] + self._split_cmd(args)).encode("utf-8")
        return docker_cmd, script

    def run_ssh_cmd(self, args: Union[str, List[str]]) -> None:
        """Run a command on the host itself; raises CalledProcessError on failure."""
        subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))

    def check_ssh_output(self, args: Union[str, List[str]]) -> str:
        """Run a command on the host itself and return its stdout decoded as UTF-8."""
        return subprocess.check_output(self._gen_ssh_prefix() + self._split_cmd(args)).decode("utf-8")

    def scp_upload_file(self, local_file: str, remote_file: str) -> None:
        """Copy ``local_file`` to ``remote_file`` on the host with scp."""
        subprocess.check_call(["scp", "-i", self.keyfile_path, local_file,
                               f"{self.login_name}@{self.addr}:{remote_file}"])

    def scp_download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Copy ``remote_file`` from the host; defaults to the current directory."""
        if local_file is None:
            local_file = "."
        subprocess.check_call(["scp", "-i", self.keyfile_path,
                               f"{self.login_name}@{self.addr}:{remote_file}", local_file])

    def start_docker(self, image="quay.io/pypa/manylinux2014_aarch64:latest") -> None:
        """Install and start Docker on the host and launch a detached container.

        The id printed by ``docker run`` is stored in ``container_id``, which
        switches the generic helpers to container mode.
        """
        self.run_ssh_cmd("sudo apt-get install -y docker.io")
        self.run_ssh_cmd(f"sudo usermod -a -G docker {self.login_name}")
        self.run_ssh_cmd("sudo service docker start")
        self.run_ssh_cmd(f"docker pull {image}")
        self.container_id = self.check_ssh_output(f"docker run -t -d -w /root {image}").strip()

    def using_docker(self) -> bool:
        """Return True once a container has been started."""
        return self.container_id is not None

    def run_cmd(self, args: Union[str, List[str]]) -> None:
        """Run a command in the container if one is active, else on the host.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero.
        """
        if not self.using_docker():
            return self.run_ssh_cmd(args)
        docker_cmd, script = self._gen_docker_exec(args)
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE)
        p.communicate(input=script)
        if p.returncode != 0:
            raise subprocess.CalledProcessError(p.returncode, docker_cmd)

    def check_output(self, args: Union[str, List[str]]) -> str:
        """Like run_cmd(), but return the command's stdout decoded as UTF-8.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero.
        """
        if not self.using_docker():
            return self.check_ssh_output(args)
        docker_cmd, script = self._gen_docker_exec(args)
        p = subprocess.Popen(docker_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        # stderr is not piped, so ``err`` is None and remote errors go to our stderr.
        out, err = p.communicate(input=script)
        if p.returncode != 0:
            raise subprocess.CalledProcessError(p.returncode, docker_cmd, output=out, stderr=err)
        return out.decode("utf-8")

    def upload_file(self, local_file: str, remote_file: str) -> None:
        """Upload a file to the host or, in docker mode, to ``/root/<remote_file>``."""
        if not self.using_docker():
            return self.scp_upload_file(local_file, remote_file)
        # Stage the file on the host, copy it into the container, then clean up.
        tmp_file = os.path.join("/tmp", os.path.basename(local_file))
        self.scp_upload_file(local_file, tmp_file)
        self.run_ssh_cmd(["docker", "cp", tmp_file, f"{self.container_id}:/root/{remote_file}"])
        self.run_ssh_cmd(["rm", tmp_file])

    def download_file(self, remote_file: str, local_file: Optional[str] = None) -> None:
        """Download a file from the host or, in docker mode, from ``/root/<remote_file>``."""
        if not self.using_docker():
            return self.scp_download_file(remote_file, local_file)
        # Copy out of the container to the host first, fetch it, then clean up.
        tmp_file = os.path.join("/tmp", os.path.basename(remote_file))
        self.run_ssh_cmd(["docker", "cp", f"{self.container_id}:/root/{remote_file}", tmp_file])
        self.scp_download_file(tmp_file, local_file)
        self.run_ssh_cmd(["rm", tmp_file])

    def list_dir(self, path: str) -> List[str]:
        """Return the entries printed by ``ls -1 path``, one per element.

        Empty lines are dropped, so the trailing newline of the ``ls`` output
        no longer produces a spurious ``""`` entry and an empty directory
        yields ``[]``.
        """
        listing = self.check_output(["ls", "-1", path])
        return [entry for entry in listing.splitlines() if entry]
def wait_for_connection(addr, port, timeout=5, attempt_cnt=5):
    """Block until a TCP connection to ``(addr, port)`` can be opened.

    Makes up to ``attempt_cnt`` attempts, each with the given ``timeout``,
    sleeping ``timeout`` seconds after every failed attempt but the last.
    A refused or timed-out final attempt re-raises its exception.
    """
    import socket

    last_attempt = attempt_cnt - 1
    for attempt in range(attempt_cnt):
        try:
            conn = socket.create_connection((addr, port), timeout=timeout)
        except (ConnectionRefusedError, socket.timeout):
            if attempt == last_attempt:
                raise
            time.sleep(timeout)
            continue
        # Only reachability matters; drop the connection right away.
        conn.close()
        return
def update_apt_repo(host: RemoteHost) -> None:
    """Stop the background apt services on ``host`` and refresh the package index.

    Both services are asked to stop first; then each is polled until
    ``systemctl is-active`` reports it inactive.  ``apt-get update`` is run
    twice with a short pause in between (presumably to ride out a transient
    first failure -- confirm).
    """
    apt_services = ("apt-daily.service", "unattended-upgrades.service")

    time.sleep(5)
    for service in apt_services:
        host.run_cmd(f"sudo systemctl stop {service} || true")
    for service in apt_services:
        host.run_cmd(f"while systemctl is-active --quiet {service}; do sleep 1; done")

    host.run_cmd("sudo apt-get update")
    time.sleep(3)
    host.run_cmd("sudo apt-get update")
def install_condaforge(host: RemoteHost) -> None:
    """Download and batch-install Miniforge3 on ``host`` and put it on PATH.

    In docker mode the PATH line is appended to ``.bashrc``; on a plain host
    it is inserted before the "If not running interactively" line of
    ``.bashrc`` with ``sed``.
    """
    print('Install conda-forge')
    host.run_cmd("curl -OL https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-aarch64.sh")
    host.run_cmd("sh -f Miniforge3-Linux-aarch64.sh -b")

    if not host.using_docker():
        sed_insert_path = [
            'sed',
            '-i',
            '\'/^# If not running interactively.*/i PATH=$HOME/miniforge3/bin:$PATH\'',
            '.bashrc',
        ]
        host.run_cmd(sed_insert_path)
        return

    host.run_cmd("echo 'PATH=$HOME/miniforge3/bin:$PATH'>>.bashrc")
def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
    """Clone OpenBLAS v0.3.15 on ``host``, build it with OpenMP, and install it.

    ``git_clone_flags`` is appended verbatim to the ``git clone`` command.
    """
    print('Building OpenBLAS')
    clone_cmd = f"git clone https://github.com/xianyi/OpenBLAS -b v0.3.15 {git_clone_flags}"
    build_and_install_cmd = (
        "pushd OpenBLAS; make USE_OPENMP=1 NO_SHARED=1 -j8; "
        "sudo make USE_OPENMP=1 NO_SHARED=1 install; popd"
    )
    host.run_cmd(clone_cmd)
    host.run_cmd(build_and_install_cmd)
def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
    """Install FFTW3's build prerequisites on ``host``, then build and install it.

    ``git_clone_flags`` is appended verbatim to the ``git clone`` command.
    """
    print("Building FFTW3")
    prerequisites_cmd = "sudo apt-get install -y ocaml ocamlbuild autoconf automake indent libtool fig2dev texinfo"
    host.run_cmd(prerequisites_cmd)
    # TODO: fix a version to build
    # TODO: consider adding flags --host=arm-linux-gnueabi --enable-single --enable-neon CC=arm-linux-gnueabi-gcc -march=armv7-a -mfloat-abi=softfp
    clone_cmd = f"git clone https://github.com/FFTW/fftw3 {git_clone_flags}"
    host.run_cmd(clone_cmd)
    host.run_cmd("pushd fftw3; sh bootstrap.sh; make -j8; sudo make install; popd")
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
    """Upload the embed_library helper script to ``host`` and run it on a wheel.

    Installs ``auditwheel`` and ``patchelf`` first (from conda when
    ``use_conda`` is true, otherwise from apt).  In docker mode the helper is
    additionally asked to update the wheel tag (``--update-tag``).
    """
    from tempfile import NamedTemporaryFile

    host.run_cmd("pip3 install auditwheel")
    if use_conda:
        host.run_cmd("conda install -y patchelf")
    else:
        host.run_cmd("sudo apt-get install -y patchelf")

    # Materialise the script locally only for as long as the upload needs it.
    with NamedTemporaryFile() as script_file:
        script_file.write(embed_library_script.encode('utf-8'))
        script_file.flush()
        host.upload_file(script_file.name, "embed_library.py")

    print('Embedding libgomp into wheel')
    embed_cmd = f"python3 embed_library.py {wheel_name}"
    if host.using_docker():
        embed_cmd += " --update-tag"
    host.run_cmd(embed_cmd)
def build_torchvision(host: RemoteHost, *,
                      branch: str = "master",
                      use_conda: bool = True,
                      git_clone_flags: str = "") -> str:
    """Build the TorchVision wheel on ``host`` and download it.

    Args:
        host: machine (or container) to build on; a ``pytorch`` checkout must
            already exist there when ``branch`` is ``"nightly"``.
        branch: PyTorch branch being built; selects the matching vision tag
            and version for the known release branches.
        use_conda: forwarded to embed_libgomp() to pick the patchelf source.
        git_clone_flags: extra flags appended to ``git clone``.  Defaults to
            ``""`` like the sibling build_torchtext/build_torchaudio.

    Returns:
        File name of the wheel, which is downloaded to the current directory.
    """
    # PyTorch release branch prefix -> (vision git tag, vision version).
    # One table keeps the checkout tag and BUILD_VERSION from drifting apart.
    releases = (
        ("v1.7.1", "v0.8.2-rc2", "0.8.2"),
        ("v1.8.0", "v0.9.0-rc3", "0.9.0"),
        ("v1.8.1", "v0.9.1-rc1", "0.9.1"),
        ("v1.9.0", "v0.10.0-rc1", "0.10.0"),
    )
    release = next(((tag, version) for prefix, tag, version in releases
                    if branch.startswith(prefix)), None)

    print('Checking out TorchVision repo')
    if release is not None:
        host.run_cmd(f"git clone https://github.com/pytorch/vision -b {release[0]} {git_clone_flags}")
    else:
        host.run_cmd(f"git clone https://github.com/pytorch/vision {git_clone_flags}")

    print('Building TorchVision wheel')
    build_vars = ""
    if branch == 'nightly':
        version = host.check_output(["if [ -f vision/version.txt ]; then cat vision/version.txt; fi"]).strip()
        if len(version) == 0:
            # In older revisions, version was embedded in setup.py
            version = host.check_output(["grep", "\"version = '\"", "vision/setup.py"]).strip().split("'")[1][:-2]
        # The build date is the first word of the latest pytorch commit subject, dashes removed.
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif release is not None:
        build_vars += f"BUILD_VERSION={release[1]}"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd vision; {build_vars} python3 setup.py bdist_wheel")
    vision_wheel_name = host.list_dir("vision/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('vision', 'dist', vision_wheel_name))

    print('Copying TorchVision wheel')
    host.download_file(os.path.join('vision', 'dist', vision_wheel_name))
    return vision_wheel_name
def build_torchtext(host: RemoteHost, *,
                    branch: str = "master",
                    use_conda: bool = True,
                    git_clone_flags: str = "") -> str:
    """Build the TorchText wheel on ``host`` and download it.

    The repository is cloned with submodules.  For a ``v1.9.0*`` PyTorch
    branch the ``v0.10.0-rc1`` tag is checked out and built as 0.10.0; for
    ``nightly`` the version is read from ``text/version.txt`` and suffixed
    with the date found in the latest pytorch commit subject.

    Returns:
        File name of the wheel, which is downloaded to the current directory.
    """
    is_release = branch.startswith("v1.9.0")
    clone_flags = git_clone_flags + " --recurse-submodules"

    print('Checking out TorchText repo')
    tag_option = "-b v0.10.0-rc1 " if is_release else ""
    host.run_cmd(f"git clone https://github.com/pytorch/text {tag_option}{clone_flags}")

    print('Building TorchText wheel')
    env_assignments = ""
    if branch == 'nightly':
        version = host.check_output(["if [ -f text/version.txt ]; then cat text/version.txt; fi"]).strip()
        commit_subject = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip()
        build_date = commit_subject.split()[0].replace("-", "")
        env_assignments += f"BUILD_VERSION={version}.dev{build_date}"
    if is_release:
        env_assignments += "BUILD_VERSION=0.10.0"
    if host.using_docker():
        env_assignments += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
    host.run_cmd(f"cd text; {env_assignments} python3 setup.py bdist_wheel")

    wheel_name = host.list_dir("text/dist")[0]
    remote_wheel = os.path.join('text', 'dist', wheel_name)
    embed_libgomp(host, use_conda, remote_wheel)

    print('Copying TorchText wheel')
    host.download_file(remote_wheel)
    return wheel_name
def build_torchaudio(host: RemoteHost, *,
                     branch: str = "master",
                     use_conda: bool = True,
                     git_clone_flags: str = "") -> str:
    """Build the TorchAudio wheel on ``host`` and download it.

    The repository is cloned with submodules.  For a ``v1.9.0*`` PyTorch
    branch the ``v0.9.0-rc2`` tag is checked out and built as 0.9.0; for
    ``nightly`` the version is grepped from ``audio/setup.py`` and suffixed
    with the date found in the latest pytorch commit subject.

    Returns:
        File name of the wheel, which is downloaded to the current directory.
    """
    print('Checking out TorchAudio repo')
    git_clone_flags += " --recurse-submodules"
    if branch.startswith("v1.9.0"):
        host.run_cmd(f"git clone https://github.com/pytorch/audio -b v0.9.0-rc2 {git_clone_flags}")
    else:
        host.run_cmd(f"git clone https://github.com/pytorch/audio {git_clone_flags}")

    # Was 'Building TorchText wheel' (copy-paste slip from build_torchtext).
    print('Building TorchAudio wheel')
    build_vars = ""
    if branch == 'nightly':
        version = host.check_output(["grep", "\"version = '\"", "audio/setup.py"]).strip().split("'")[1][:-2]
        # The build date is the first word of the latest pytorch commit subject, dashes removed.
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    if branch.startswith("v1.9.0"):
        build_vars += "BUILD_VERSION=0.9.0"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

    host.run_cmd(f"cd audio; {build_vars} python3 setup.py bdist_wheel")
    wheel_name = host.list_dir("audio/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('audio', 'dist', wheel_name))

    print('Copying TorchAudio wheel')
    host.download_file(os.path.join('audio', 'dist', wheel_name))
    return wheel_name
def start_build(host: RemoteHost, *,
                branch="master",
                compiler="gcc-8",
                use_conda=True,
                python_version="3.8",
                keep_running=False,
                shallow_clone=True,
                instance=None) -> Tuple[str, str]:
    """Build PyTorch, TorchVision, TorchAudio and TorchText wheels on ``host``.

    Args:
        host: prepared remote host (optionally already in docker mode).
        branch: PyTorch branch or tag to build; ``"nightly"`` and ``"v1.*"``
            get an explicit PYTORCH_BUILD_VERSION.
        compiler: only ``"gcc-8"`` has an effect, and only on a non-docker
            host booted from the Ubuntu 18.04 AMI.
        use_conda: install Python and dependencies from conda-forge; forced
            to True in docker mode.
        python_version: Python version installed when ``use_conda`` is true.
        keep_running: skip terminating the EC2 instance after the build.
        shallow_clone: clone repositories with ``--depth 1``.
        instance: EC2 instance to terminate after the build.  When omitted,
            the module-level ``inst`` set by the script entry point is used
            if it exists.

    Returns:
        ``(pytorch_wheel_name, vision_wheel_name)``; the wheels themselves
        are downloaded to the current directory.
    """
    if host.using_docker() and not use_conda:
        print("Auto-selecting conda option for docker images")
        use_conda = True

    if use_conda:
        install_condaforge(host)
        host.run_cmd(f"conda install -y python={python_version} numpy pyyaml")

    git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""

    print('Configuring the system')
    if not host.using_docker():
        update_apt_repo(host)
        host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
    else:
        host.run_cmd("yum install -y sudo")
        host.run_cmd("conda install -y ninja")

    if not use_conda:
        host.run_cmd("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip")
    host.run_cmd("pip3 install dataclasses typing-extensions")

    # Install and switch to gcc-8 on Ubuntu-18.04
    if not host.using_docker() and host.ami == ubuntu18_04_ami and compiler == 'gcc-8':
        host.run_cmd("sudo apt-get install -y g++-8 gfortran-8")
        host.run_cmd("sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100")
        host.run_cmd("sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 100")
        host.run_cmd("sudo update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-8 100")

    if not use_conda:
        print("Installing Cython + numpy from PyPy")
        host.run_cmd("sudo pip3 install Cython")
        host.run_cmd("sudo pip3 install numpy")

    build_OpenBLAS(host, git_clone_flags)
    # The FFTW build is currently disabled:
    # build_FFTW(host, git_clone_flags)

    if host.using_docker():
        print("Move libgfortant.a into a standard location")
        # HACK: pypa libgfortran.a is compiled without PIC, which leads to the following error
        # libgfortran.a(error.o)(.text._gfortrani_st_printf+0x34): unresolvable R_AARCH64_ADR_PREL_PG_HI21 relocation against symbol `__stack_chk_guard@@GLIBC_2.17'
        # Workaround by copying gfortran library from the host
        host.run_ssh_cmd("sudo apt-get install -y gfortran-8")
        host.run_cmd("mkdir -p /usr/lib/gcc/aarch64-linux-gnu/8")
        host.run_ssh_cmd(["docker", "cp", "/usr/lib/gcc/aarch64-linux-gnu/8/libgfortran.a",
                          f"{host.container_id}:/opt/rh/devtoolset-9/root/usr/lib/gcc/aarch64-redhat-linux/9/"
                          ])

    print('Checking out PyTorch repo')
    host.run_cmd(f"git clone --recurse-submodules -b {branch} https://github.com/pytorch/pytorch {git_clone_flags}")

    print('Building PyTorch wheel')
    build_vars = ""
    if branch == 'nightly':
        # The build date is the first word of the latest commit subject, dashes removed.
        build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "")
        version = host.check_output("cat pytorch/version.txt").strip()[:-2]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
    if branch.startswith("v1."):
        # Drop the leading "v" and anything from the first "-" on ("v1.9.0-rc1" -> "1.9.0").
        # Slicing with branch.find('-') chopped the last character when there was
        # no "-" at all, because find() then returns -1.
        release_version = branch[1:].split('-', 1)[0]
        build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={release_version} PYTORCH_BUILD_NUMBER=1"
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
    host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
    pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
    embed_libgomp(host, use_conda, os.path.join('pytorch', 'dist', pytorch_wheel_name))
    print('Copying the wheel')
    host.download_file(os.path.join('pytorch', 'dist', pytorch_wheel_name))

    print('Installing PyTorch wheel')
    host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
    vision_wheel_name = build_torchvision(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
    build_torchaudio(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
    build_torchtext(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)

    if keep_running:
        return pytorch_wheel_name, vision_wheel_name

    if instance is None:
        # Backward compatibility: the script entry point keeps the launched
        # instance in a module-level ``inst`` rather than passing it in.
        instance = globals().get('inst')
    if instance is None:
        # No handle to terminate; don't discard a finished build with a NameError.
        print('No EC2 instance handle available, leaving the instance running')
        return pytorch_wheel_name, vision_wheel_name

    print(f'Waiting for instance {instance.id} to terminate')
    instance.terminate()
    instance.wait_until_terminated()
    return pytorch_wheel_name, vision_wheel_name
# Source of the helper script that embed_libgomp() uploads to the build host
# as embed_library.py and runs there against a wheel.  It is kept as a string
# because it runs on the remote machine; it imports auditwheel and invokes the
# patchelf binary.  "\\n" is escaped so the generated script contains a literal
# backslash-n escape rather than a line break.
embed_library_script = """
#!/usr/bin/env python3

from auditwheel.patcher import Patchelf
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.repair import copylib
from auditwheel.lddtree import lddtree
from subprocess import check_call
import os
import shutil
import sys
from tempfile import TemporaryDirectory


def replace_tag(filename):
    with open(filename, 'r') as f:
        lines = f.read().split("\\n")
    for i,line in enumerate(lines):
        if not line.startswith("Tag: "):
            continue
        lines[i] = line.replace("-linux_", "-manylinux2014_")
        print(f'Updated tag from {line} to {lines[i]}')

    with open(filename, 'w') as f:
        f.write("\\n".join(lines))


class AlignedPatchelf(Patchelf):
    def set_soname(self, file_name: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--set-soname', new_soname, file_name])

    def replace_needed(self, file_name: str, soname: str, new_soname: str) -> None:
        check_call(['patchelf', '--page-size', '65536', '--replace-needed', soname, new_soname, file_name])


def embed_library(whl_path, lib_soname, update_tag=False):
    patcher = AlignedPatchelf()
    out_dir = TemporaryDirectory()
    whl_name = os.path.basename(whl_path)
    tmp_whl_name = os.path.join(out_dir.name, whl_name)
    with InWheelCtx(whl_path) as ctx:
        torchlib_path = os.path.join(ctx._tmpdir.name, 'torch', 'lib')
        ctx.out_wheel=tmp_whl_name
        new_lib_path, new_lib_soname = None, None
        for filename, elf in elf_file_filter(ctx.iter_files()):
            if not filename.startswith('torch/lib'):
                continue
            libtree = lddtree(filename)
            if lib_soname not in libtree['needed']:
                continue
            lib_path = libtree['libs'][lib_soname]['path']
            if lib_path is None:
                print(f"Can't embed {lib_soname} as it could not be found")
                break
            if lib_path.startswith(torchlib_path):
                continue

            if new_lib_path is None:
                new_lib_soname, new_lib_path = copylib(lib_path, torchlib_path, patcher)
            patcher.replace_needed(filename, lib_soname, new_lib_soname)
            print(f'Replacing {lib_soname} with {new_lib_soname} for {filename}')
        if update_tag:
            # Add manylinux2014 tag
            for filename in ctx.iter_files():
                if os.path.basename(filename) != 'WHEEL':
                    continue
                replace_tag(filename)
    shutil.move(tmp_whl_name, whl_path)


if __name__ == '__main__':
    embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
"""
def run_tests(host: RemoteHost, whl: str, branch='master') -> None:
    """Install a locally built wheel on ``host`` and run PyTorch's test_torch.py.

    Args:
        host: remote machine to test on.
        whl: local path of the wheel to upload and install.
        branch: PyTorch branch whose test suite is cloned and run.
    """
    print('Configuring the system')
    update_apt_repo(host)
    host.run_cmd("sudo apt-get install -y python3-pip git")
    host.run_cmd("sudo pip3 install Cython")
    host.run_cmd("sudo pip3 install numpy")
    host.upload_file(whl, ".")
    # The upload lands in the remote working directory under the file's base
    # name, so install it by that name rather than by the local path.
    remote_whl = os.path.basename(whl)
    host.run_cmd(f"sudo pip3 install {remote_whl}")
    # Smoke test; the closing parenthesis of print() was missing, which made
    # the remote interpreter fail with a SyntaxError.
    host.run_cmd("python3 -c 'import torch;print(torch.rand((3,3)))'")
    host.run_cmd(f"git clone -b {branch} https://github.com/pytorch/pytorch")
    host.run_cmd("cd pytorch/test; python3 test_torch.py -v")
def get_instance_name(instance) -> Optional[str]:
    """Return the value of the instance's first ``Name`` tag, or None if absent."""
    tags = instance.tags
    if tags is None:
        return None
    return next((tag['Value'] for tag in tags if tag['Key'] == 'Name'), None)
def list_instances(instance_type: str) -> None:
    """Print id, Name tag, public DNS name and state of every instance of a type."""
    print(f"All instances of type {instance_type}")
    for instance in ec2_instances_of_type(instance_type):
        name = get_instance_name(instance)
        state = instance.state['Name']
        print(f"{instance.id} {name} {instance.public_dns_name} {state}")
def terminate_instances(instance_type: str) -> None:
    """Terminate every instance of ``instance_type`` and wait until all are gone.

    Termination is requested for all instances first and only then awaited,
    so the instances shut down concurrently.
    """
    print(f"Terminating all instances of type {instance_type}")
    doomed = list(ec2_instances_of_type(instance_type))

    for target in doomed:
        print(f"Terminating {target.id}")
        target.terminate()

    print("Waiting for termination to complete")
    for target in doomed:
        target.wait_until_terminated()
def parse_arguments():
    """Parse the script's command-line options and return the argparse namespace."""
    from argparse import ArgumentParser

    # The text used to be passed positionally, which made it the program name
    # (``prog``) shown in the usage line instead of the description.
    parser = ArgumentParser(description="Build and test AARCH64 wheels using EC2")
    parser.add_argument("--key-name", type=str)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--build-only", action="store_true")
    parser.add_argument("--test-only", type=str)
    parser.add_argument("--os", type=str, choices=list(os_amis.keys()), default='ubuntu18_04')
    parser.add_argument("--python-version", type=str, choices=['3.6', '3.7', '3.8', '3.9'], default=None)
    parser.add_argument("--alloc-instance", action="store_true")
    parser.add_argument("--list-instances", action="store_true")
    parser.add_argument("--keep-running", action="store_true")
    parser.add_argument("--terminate-instances", action="store_true")
    parser.add_argument("--instance-type", type=str, default="t4g.2xlarge")
    parser.add_argument("--branch", type=str, default="master")
    parser.add_argument("--use-docker", action="store_true")
    parser.add_argument("--compiler", type=str, choices=['gcc-7', 'gcc-8', 'gcc-9', 'clang'], default="gcc-8")
    return parser.parse_args()
# Script entry point: list/terminate instances, or launch one and use it to
# test a wheel, pre-provision it, or run the full build.
if __name__ == '__main__':
    args = parse_arguments()
    ami = os_amis[args.os]
    keyfile_path, key_name = compute_keyfile_path(args.key_name)

    if args.list_instances:
        list_instances(args.instance_type)
        sys.exit(0)

    if args.terminate_instances:
        terminate_instances(args.instance_type)
        sys.exit(0)

    if len(key_name) == 0:
        raise Exception("""
Cannot start build without key_name, please specify
--key-name argument or AWS_KEY_NAME environment variable.""")
    if len(keyfile_path) == 0 or not os.path.exists(keyfile_path):
        raise Exception(f"""
Cannot find keyfile with name: [{key_name}] in path: [{keyfile_path}], please
check `~/.ssh/` folder or manually set SSH_KEY_PATH environment variable.""")

    # Starting the instance
    # `inst` must stay a module-level name: start_build() reads it to
    # terminate the instance once the build is done.
    # --instance-type is now honoured here too, not only for list/terminate.
    inst = start_instance(key_name, ami=ami, instance_type=args.instance_type)
    # Use the resolved key name; args.key_name is None when the key comes
    # from the AWS_KEY_NAME environment variable.
    instance_name = f'{key_name}-{args.os}'
    if args.python_version is not None:
        instance_name += f'-py{args.python_version}'
    inst.create_tags(DryRun=False, Tags=[{
        'Key': 'Name',
        'Value': instance_name,
    }])
    addr = inst.public_dns_name
    wait_for_connection(addr, 22)

    # Per the notes next to os_amis, the RedHat image logs in as ec2-user
    # while the Ubuntu images log in as ubuntu.
    login_name = 'ec2-user' if args.os.startswith('redhat') else 'ubuntu'
    host = RemoteHost(addr, keyfile_path, login_name)
    host.ami = ami
    if args.use_docker:
        update_apt_repo(host)
        host.start_docker()

    if args.test_only:
        # Test against the requested branch instead of always using master.
        run_tests(host, args.test_only, branch=args.branch)
        sys.exit(0)

    if args.alloc_instance:
        if args.python_version is None:
            sys.exit(0)
        install_condaforge(host)
        host.run_cmd(f"conda install -y python={args.python_version} numpy pyyaml")
        sys.exit(0)

    python_version = args.python_version if args.python_version is not None else '3.8'
    start_build(host,
                branch=args.branch,
                compiler=args.compiler,
                python_version=python_version,
                keep_running=args.keep_running)