-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add "trainremote" command #906
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -464,6 +464,169 @@ def run(self, args): | |
print(f"Unrecognized framework: {framework}. Please specify one of " | ||
f"'tensorflow' or 'pytorch'") | ||
|
||
class TrainRemote(BaseCommand): | ||
""" | ||
train models on a remote server | ||
""" | ||
# example command: | ||
# donkey trainremote --tub ~/mycar/data/tub_4_21-06-11/ ~/mycar/data/tub_3_21-06-11/ --path ~/mycar --model modelname | ||
# --url https://hq.robocarstore.com/train/submit_job --get https://hq.robocarstore.com/train/refresh_job_statuses | ||
|
||
WLAN="wlan0" | ||
HOTSPOT_IF_NAME="uap0" | ||
REFRESH_JOB_STATUS_URL = "https://hq.robocarstore.com/train/refresh_job_statuses" | ||
SUBMIT_JOB_URL = "https://hq.robocarstore.com/train/submit_job" | ||
TIMEOUT = 900 # limit the training to 15min | ||
|
||
def parse_args(self, args): | ||
parser = argparse.ArgumentParser(prog='train', usage='%(prog)s [options]') | ||
parser.add_argument('--tub', nargs='+', help='tub data for training') | ||
parser.add_argument('--path', default=None, help='path where to create car folder') | ||
parser.add_argument('--url', default=None, help='url of the remote server to submit job') | ||
parser.add_argument('--get', default=None, help='url of the remote server to get statuses') | ||
parser.add_argument('--model', default=None, help='output model name') | ||
|
||
parsed_args = parser.parse_args(args) | ||
return parsed_args | ||
|
||
def generate_tub_archive(self, tub_paths, carapp_path): | ||
import tempfile | ||
import tarfile | ||
from pathlib import Path | ||
print("generating tub archive") | ||
f = tempfile.NamedTemporaryFile(mode='w+b', suffix='.tar.gz', delete=False) | ||
|
||
with tarfile.open(fileobj=f, mode='w:gz') as tar: | ||
for tub_path in tub_paths: | ||
p = Path(tub_path) | ||
tar.add(p, arcname=p.name) | ||
tar.add(f"{carapp_path}/myconfig.py", arcname="myconfig.py") | ||
|
||
f.close() | ||
|
||
return f.name | ||
|
||
def get_wlan_mac_address(self, wlan): | ||
import netifaces | ||
|
||
interfaces = netifaces.interfaces() | ||
if wlan not in interfaces: | ||
return "None" | ||
addrs = netifaces.ifaddresses(wlan) | ||
|
||
if addrs is None: | ||
return "None" | ||
|
||
if (netifaces.AF_LINK in addrs) and (len(addrs[netifaces.AF_LINK]) == 1): | ||
return addrs[netifaces.AF_LINK][0]['addr'] | ||
else: | ||
return "None" | ||
|
||
def submit_job(self, tub_paths, carapp_path, submit_job_url): | ||
import requests | ||
from requests_toolbelt.multipart.encoder import MultipartEncoder | ||
from donkeycar import __version__ | ||
|
||
filename = self.generate_tub_archive(tub_paths, carapp_path) | ||
mp_encoder = MultipartEncoder( | ||
fields={ | ||
'device_id': self.get_wlan_mac_address(self.WLAN), | ||
'hostname' : gethostname(), | ||
'tub_archive_file': ('file.tar.gz', open(filename, 'rb'), 'application/gzip'), | ||
'donkeycar_version': str(__version__) | ||
} | ||
) | ||
print(f"URL submitted to: {submit_job_url}") | ||
print(f"Data to submit: {mp_encoder}") | ||
r = requests.post( | ||
submit_job_url, | ||
data=mp_encoder, # The MultipartEncoder is posted as data, don't use files=...! | ||
# The MultipartEncoder provides the content-type header with the boundary: | ||
headers={'Content-Type': mp_encoder.content_type} | ||
) | ||
print(f"Submission response: HTTP {r.status_code}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here too |
||
if (r.status_code == 200): | ||
# if HTTP 200 OK | ||
if ("job_uuid" in r.json()): | ||
try: | ||
uuid = r.json()['job_uuid'] | ||
print(f"Submitted a Training Job to the remote server {submit_job_url}\n uuid: {uuid}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and here |
||
return uuid | ||
except Exception as e: | ||
print(e) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and here |
||
raise Exception("Failed to call submit job") | ||
else: | ||
raise Exception("Failed to call submit job") | ||
else: | ||
raise Exception("Failed to call submit job") | ||
|
||
def get_latest_job_status_from_hq(self, refresh_job_statuses, job_uuids): | ||
import requests | ||
|
||
print(f"Getting lastest job status for uuid {job_uuids}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... |
||
response = requests.post(refresh_job_statuses, data={"job_uuids": job_uuids}) | ||
if response.status_code == 200: | ||
return response.json() | ||
else: | ||
print(response.status_code) | ||
print(response.content) | ||
raise Exception("Problem requesting latest job status from hq") | ||
|
||
def run(self, args): | ||
from datetime import datetime | ||
import time | ||
|
||
args = self.parse_args(args) | ||
tub_paths = list(args.tub) | ||
car_path = make_dir(args.path or './') | ||
if args.url: | ||
url = args.url | ||
else: | ||
url = self.SUBMIT_JOB_URL | ||
if args.get: | ||
url_status = args.get | ||
else: | ||
url_status = self.REFRESH_JOB_STATUS_URL | ||
|
||
uuid = self.submit_job(tub_paths, car_path, url) | ||
if uuid: | ||
if args.model: | ||
model_name = os.path.splitext(os.path.basename(args.model))[0] # get the filename without base path and extension | ||
else: | ||
model_name = f"job_{uuid}" | ||
|
||
start_time = time.time() | ||
run_time = 0 | ||
checkagain = True | ||
while checkagain and (run_time < self.TIMEOUT): | ||
result = self.get_latest_job_status_from_hq(url_status, uuid)[0] | ||
checkagain = result['status'] == "SCHEDULED" or result['status'] == "TRAINING" | ||
print(f"Training Status:{result['status']} at {datetime.now().strftime('%Y-%m-%d, %H:%M:%S')}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... |
||
|
||
time.sleep(10) | ||
if result['status'] == "COMPLETED": | ||
print("Training completed!!") | ||
print(f"Model URL: {result['model_url']}") | ||
print(f"Model Accuracy URL: {result['model_accuracy_url']}") | ||
print(f"Model Movie URL: {result['model_movie_url']}") | ||
|
||
if not os.path.isdir(f"{car_path}/movies/"): | ||
print(f"Creating movie folder '{car_path}/movies/' ") | ||
os.mkdir(f"{car_path}/movies/") | ||
|
||
command = ["curl", "--fail", result['model_url'], "--output", f"{car_path}/models/{model_name}.h5"] | ||
proc = subprocess.Popen(command) | ||
command = ["curl", "--fail", result['model_accuracy_url'], "--output", f"{car_path}/models/{model_name}.png"] | ||
proc = subprocess.Popen(command) | ||
command = ["curl", "--fail", result['model_movie_url'], "--output", f"{car_path}/movies/{model_name}.mp4"] | ||
proc = subprocess.Popen(command) | ||
print("Downloaded model, accuracy graph and movie from the URLs") | ||
break | ||
elif result['status'] in ["NO_CAPACITY", "NO_QUOTA", "SPOT_REQUEST_FAILED", "TIMEOUT"]: | ||
raise Exception(f"Failed to train the submitted job\nError : {result['status']}") | ||
|
||
run_time = time.time()-start_time | ||
print(f"Time spent: {run_time} s") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... |
||
|
||
class Gui(BaseCommand): | ||
def run(self, args): | ||
|
@@ -486,6 +649,7 @@ def execute_from_command_line(): | |
'cnnactivations': ShowCnnActivations, | ||
'update': UpdateCar, | ||
'train': Train, | ||
'trainremote': TrainRemote, | ||
'ui': Gui, | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,9 @@ dependencies: | |
- psutil | ||
- kivy=2.0.0 | ||
- plotly | ||
- requests | ||
- requests_toolbelt | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems we can't find this package in conda main and conda forge channels, maybe it's only available in pip? |
||
- netifaces | ||
- pip: | ||
- tensorflow==2.2.0 | ||
- git+https://github.com/autorope/keras-vis.git | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,9 @@ dependencies: | |
- kivy=2.0.0 | ||
- plotly | ||
- tensorflow==2.2.0 | ||
- requests | ||
- requests_toolbelt | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to produce the same problem as in the Mac conda yaml file. |
||
- netifaces | ||
- pip: | ||
- git+https://github.com/autorope/keras-vis.git | ||
- simple-pid | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,9 @@ dependencies: | |
- kivy=2.0.0 | ||
- plotly | ||
- psutil | ||
- requests | ||
- requests_toolbelt | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No idea if this works in windows... do you have any means to check that? |
||
- netifaces | ||
- pip: | ||
- git+https://github.com/autorope/keras-vis.git | ||
- simple-pid | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please switch to logging statements instead of print, that's what we are using now.