Skip to content

Commit

Permalink
support custom script remote task submit
Browse files Browse the repository at this point in the history
  • Loading branch information
jianzfb committed Oct 9, 2024
1 parent 7ce32a2 commit e0c3243
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 164 deletions.
28 changes: 18 additions & 10 deletions antgo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DEFINE_string('branch', None, '')
DEFINE_string('commit', None, '')
DEFINE_string('image', '', '') # 镜像
DEFINE_string('script', None, '') # 自定义脚本.sh
DEFINE_int('cpu', 0, 'set cpu number') # cpu 数
DEFINE_int('gpu', 0, 'set gpu number') # gpu 数
DEFINE_int('memory', 0, 'set memory size (M)') # 内存大小(单位M)
Expand Down Expand Up @@ -752,6 +753,10 @@ def main():
shutil.copyfile(os.path.join(os.path.dirname(__file__), 'resource', 'templates', 'project.json'), './.project.json')

if args.ssh or args.k8s:
# 远程提交模式,仅支持训练和推断
if action_name not in ['train', 'eval']:
logging.error('Antgo remote task submit mode only support train and eval')
return
if args.root.startswith('ali:'):
# 尝试进行认证,从而保证当前路径下生成认证信息
ali = Aligo()
Expand All @@ -763,8 +768,8 @@ def main():
with open('./.project.json', 'r') as fp:
project_info = json.load(fp)

if action_name in ['train', 'activelearning']:
# train, activelearning
if action_name in ['train']:
# train
# 项目基本信息
project_info['image'] = args.image # 镜像名称
if args.exp not in project_info['exp']:
Expand All @@ -784,7 +789,7 @@ def main():
with open('./.project.json', 'w') as fp:
json.dump(project_info,fp)
else:
# eval, export
# eval
# 匹配实验记录(exp, config)
# (1) root 匹配
# (2) 默认匹配最新实验
Expand Down Expand Up @@ -833,18 +838,21 @@ def main():

# 直接进行任务提交
# step 1.1: 检查提交脚本配置
if args.ssh:
# ssh提交
if args.ssh and args.script is None:
# 基于ssh远程管理
sys_argv_cmd = sys_argv_cmd.replace('--ssh', '')
sys_argv_cmd = sys_argv_cmd.replace(' ', ' ')
sys_argv_cmd = f'antgo {sys_argv_cmd}'

ssh_submit_process_func(time.strftime(f"%Y-%m-%d.%H-%M-%S", time.localtime(now_time)), sys_argv_cmd, 0 if args.gpu_id == '' else len(args.gpu_id.split(',')), args.cpu, args.memory, ip=args.ip, exp=args.exp, check_data=args.data, env=args.version)
else:
# 自定义脚本提交
sys_argv_cmd = sys_argv_cmd.replace(' ', ' ')
sys_argv_cmd = f'antgo {sys_argv_cmd}'
custom_submit_process_func(time.strftime(f"%Y-%m-%d.%H-%M-%S", time.localtime(now_time)), sys_argv_cmd, 0 if args.gpu_id == '' else len(args.gpu_id.split(',')), args.cpu, args.memory, ip=args.ip, exp=args.exp, check_data=args.data)
elif args.ssh and args.script is not None:
# 自定义脚本提交,提交远程机器后的启动脚本,所有启动项提交脚本者负责。环境能力,如暴漏GPU由框架负责
assert(args.image is not None and args.image != '')
ssh_submit2_process_func(time.strftime(f"%Y-%m-%d.%H-%M-%S", time.localtime(now_time)), f'bash {args.script}', args.image, 0 if args.gpu_id == '' else len(args.gpu_id.split(',')), args.cpu, args.memory, ip=args.ip, exp=args.exp)
elif args.k8s:
# TODO,基于k8s远程管理
logging.error('Not support k8s now.')
pass

# 清理临时存储信息
if os.path.exists('./aligo.json'):
Expand Down
2 changes: 1 addition & 1 deletion antgo/pipeline/functional/mixins/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,7 +2698,7 @@ def build(self, platform='android/arm64-v8a', output_folder='./deploy', project_

# 创建grpc客户端python代码(用于测试)
if not os.path.exists(os.path.join(output_folder, f'grpc_client.py')):
grpc_client_code_template_file = './templates/grpc_client_code.py'
grpc_client_code_template_file = './templates/grpc_client_code'
if call_mode == 'callback' or call_mode == 'asyn':
grpc_client_code_template_file = './templates/grpc_stream_client_code.py'

Expand Down
7 changes: 1 addition & 6 deletions antgo/script/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from .ssh_submit import *
from .local_submit import *
from .custom_submit import *
__all__=[
'ssh_submit_process_func',
'ssh_submit_resource_check_func',
'local_submit_process_func',
'local_submit_resource_check_func',
'custom_submit_process_func',
'custom_submit_resource_check_func'
'ssh_submit2_process_func'
]
62 changes: 0 additions & 62 deletions antgo/script/custom_submit.py

This file was deleted.

52 changes: 0 additions & 52 deletions antgo/script/local_submit.py

This file was deleted.

1 change: 1 addition & 0 deletions antgo/script/ssh-submit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ uploadProject(){
pidarr=()
for target_ip in ${target_ip_list[@]}
do
echo ${target_ip}
uploadProject ${username} ${target_ip} ${project} ${submit_time} &
pid=$!
pidarr+=(${pid})
Expand Down
153 changes: 152 additions & 1 deletion antgo/script/ssh_submit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
sys.path.insert(0, '/workspace/antgo')
import logging
import os
import time
Expand Down Expand Up @@ -138,6 +137,7 @@ def analyze_all_dependent_data(config_file_path):

return dependent_data_list


def ssh_submit_process_func(create_time, sys_argv, gpu_num, cpu_num, memory_size, task_name=None, ip='', exp='', check_data=False, env='master'):
# 前提假设,调用此函数前当前目录下需要存在项目代码
# 遍历所有注册的设备,找到每个设备的空闲GPU
Expand Down Expand Up @@ -339,6 +339,157 @@ def ssh_submit_process_func(create_time, sys_argv, gpu_num, cpu_num, memory_size
os.remove('./extra-config.py')
return True


def _find_machine_with_free_gpus(gpu_num):
    """Return the IP of the first registered SSH machine with >= ``gpu_num`` free GPUs, or ''."""
    config_folder = os.path.join(os.environ['HOME'], '.config', 'antgo')
    for file_name in os.listdir(config_folder):
        # Registered machines are stored as ssh-<ip>-submit-config.yaml.
        if not (file_name.endswith('.yaml') and file_name.startswith('ssh')):
            continue
        terms = file_name.split('-')
        if len(terms) != 4 or terms[1] == '':
            continue

        with open(os.path.join(config_folder, file_name), encoding='utf-8', mode='r') as fp:
            ssh_config_info = yaml.safe_load(fp)

        # Check GPU occupancy on this machine.
        logging.info(f'Analyze IP: {terms[1]}')
        info = remote_gpu_running_info(ssh_config_info["config"]["username"], ssh_config_info["config"]["ip"])
        if len(info['free_gpus']) >= gpu_num:
            return ssh_config_info["config"]["ip"]
    return ''


def _remote_running_container_ids(username, host):
    """Return the container IDs listed by ``docker ps`` on ``host`` (empty list on failure)."""
    cmd = f'ssh {username}@{host} docker ps'
    # subprocess.run waits for the process and drains both pipes, so returncode
    # is meaningful and a full stderr pipe cannot block the child.
    ret = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if ret.returncode:
        logging.error("Couldnt get running info")
        return []

    running_info = ret.stdout.decode('utf-8').split('\n')
    if len(running_info) <= 1:
        logging.error(f"Couldnt parse container info on {host}")
        return []

    # First line is the `docker ps` header; the first column of each row is the container ID.
    container_ids = [line.split(' ')[0] for line in running_info[1:] if line != '']
    # An empty ID would be a prefix of every string, so drop it.
    return [container_id for container_id in container_ids if container_id != '']


def ssh_submit2_process_func(create_time, exe_script, base_image, gpu_num, cpu_num, memory_size, task_name=None, ip='', exp=''):
    """Submit a user-supplied launch script to remote machine(s) over SSH.

    Precondition: the project code and ``./.project.json`` exist in the current
    working directory.

    Args:
        create_time: Submit timestamp string, forwarded to ``ssh-submit.sh``.
        exe_script: Command to run remotely (e.g. ``bash my.sh``); the
            ``--device-num/--nnodes/--master-port/--master-addr`` options are appended.
        base_image: Docker image name. When empty or None, falls back to the
            ``image`` entry in ``.project.json``.
        gpu_num: Number of GPUs required on each target machine.
        cpu_num: CPU count forwarded to ``ssh-submit.sh``.
        memory_size: Memory size in MB forwarded to ``ssh-submit.sh``.
        task_name: Unused; kept for interface compatibility.
        ip: Comma-separated target IPs. When empty, the registered machines
            under ``~/.config/antgo`` are searched for one with enough free GPUs.
        exp: Experiment name whose latest record in ``.project.json`` is
            updated with the container id and master IP.

    Returns:
        True when the task was submitted and the master container was found,
        False otherwise.
    """
    with open('./.project.json', 'r') as fp:
        project_info = json.load(fp)

    logging.info("Analyze cluster environment.")
    if ip == '':
        # Automatically search for a usable remote machine.
        ip = _find_machine_with_free_gpus(gpu_num)
        if ip == '':
            # Without this guard the code below would try to open
            # 'ssh--submit-config.yaml' and crash with FileNotFoundError.
            logging.error("No enough machine resource")
            return False

    username = ''
    password = ''
    target_machine_info_list = []
    for target_ip in ip.split(','):
        ssh_submit_config_file = os.path.join(os.environ['HOME'], '.config', 'antgo', f'ssh-{target_ip}-submit-config.yaml')
        with open(ssh_submit_config_file, encoding='utf-8', mode='r') as fp:
            ssh_config_info = yaml.safe_load(fp)

        # Check GPU occupancy on the target machine.
        info = remote_gpu_running_info(ssh_config_info["config"]["username"], ssh_config_info["config"]["ip"])
        if len(info['free_gpus']) < gpu_num:
            logging.error(f"No enough gpu in {target_ip}.")
            return False

        # NOTE: the credentials of the last target machine are the ones passed to ssh-submit.sh.
        username = ssh_config_info["config"]["username"]
        password = ssh_config_info["config"]["password"]
        target_machine_info_list.append({
            'ip': target_ip,
            'username': username,
            'password': password,
            'gpus': info['free_gpus']
        })
    logging.info(f"Apply target machine resource {target_machine_info_list}")

    # Fall back to the project's recorded image when none (None or '') was given.
    image_name = base_image
    if not image_name and project_info.get('image'):
        image_name = project_info['image']

    if password == '':
        password = 'default'

    print(f'Use image {image_name}')
    project_name = os.path.basename(os.path.abspath(os.path.curdir))
    submit_time = create_time
    env = '-'

    master_machine_info = target_machine_info_list[0]
    target_machine_ips = ','.join([v['ip'] for v in target_machine_info_list])
    print(f'project_name {project_name}')
    print(f'target_machine_ips {target_machine_ips}')

    submit_script = os.path.join(os.path.dirname(__file__), 'ssh-submit.sh')
    exe_script = f'{exe_script} --device-num={gpu_num} --nnodes={len(target_machine_info_list)} --master-port=8990 --master-addr={master_machine_info["ip"]}'
    submit_cmd = f'bash {submit_script} {username} {password} {target_machine_ips} {gpu_num} {cpu_num} {memory_size}M "{exe_script}" {image_name} {project_name} {env} {submit_time}'

    # Run the submit script and capture its output to extract the container ids.
    print('submit command')
    print(submit_cmd)
    ret = subprocess.run(submit_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    content = ret.stdout.decode('utf-8')
    print(content)

    # The last len(target machines) lines before the trailing newline are taken as container ids.
    container_id_list = content.split('\n')[(-1-len(target_machine_info_list)):-1]

    # Query the master machine once and find which returned id is running there.
    ssh_submit_config_file = os.path.join(os.environ['HOME'], '.config', 'antgo', f'ssh-{master_machine_info["ip"]}-submit-config.yaml')
    with open(ssh_submit_config_file, encoding='utf-8', mode='r') as fp:
        ssh_config_info = yaml.safe_load(fp)
    running_container_ids = _remote_running_container_ids(ssh_config_info["config"]["username"], ssh_config_info["config"]["ip"])

    # `docker ps` prints abbreviated ids, so match them as prefixes of the submitted ids.
    master_container_id = next(
        (
            container_id
            for container_id in container_id_list
            if any(container_id.startswith(running_id) for running_id in running_container_ids)
        ),
        ''
    )
    if master_container_id == '':
        logging.error('Couldnt find task container id.')
        return False

    # Record the container id and master IP in the latest record of this experiment.
    with open('./.project.json', 'r') as fp:
        project_info = json.load(fp)
    project_info['exp'][exp][-1]['id'] = master_container_id
    project_info['exp'][exp][-1]['ip'] = master_machine_info['ip']
    with open('./.project.json', 'w') as fp:
        json.dump(project_info,fp)

    return True

# 检查任务资源是否满足
def ssh_submit_resource_check_func(gpu_num, cpu_num, memory_size):
# TODO,支持资源检查
Expand Down
Loading

0 comments on commit e0c3243

Please sign in to comment.