Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "trainremote" command #906

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions donkeycar/management/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,169 @@ def run(self, args):
print(f"Unrecognized framework: {framework}. Please specify one of "
f"'tensorflow' or 'pytorch'")

class TrainRemote(BaseCommand):
"""
Train models on a remote server.

The command packs the given tubs together with the car's myconfig.py
into a tar.gz archive, submits it to a remote training service, polls
the job status and finally downloads the resulting model, accuracy
graph and movie into the car folder.
"""
# example command:
# donkey trainremote --tub ~/mycar/data/tub_4_21-06-11/ ~/mycar/data/tub_3_21-06-11/ --path ~/mycar --model modelname
# --url https://hq.robocarstore.com/train/submit_job --get https://hq.robocarstore.com/train/refresh_job_statuses

# Network interface whose MAC address is sent as 'device_id' on submission.
WLAN="wlan0"
# NOTE(review): not referenced by any method visible in this class —
# presumably the hotspot interface name; confirm it is still needed.
HOTSPOT_IF_NAME="uap0"
# Default status endpoint, used when --get is not given.
REFRESH_JOB_STATUS_URL = "https://hq.robocarstore.com/train/refresh_job_statuses"
# Default submission endpoint, used when --url is not given.
SUBMIT_JOB_URL = "https://hq.robocarstore.com/train/submit_job"
# Seconds; run() stops polling the job status once this much time has elapsed.
TIMEOUT = 900 # limit the training to 15min

def parse_args(self, args):
    """
    Parse the command line arguments of the ``trainremote`` command.

    :param args:    list of argument strings, without the command name.
    :return:        argparse.Namespace with the attributes ``tub`` (list of
                    str), ``path``, ``url``, ``get`` and ``model`` (each a
                    str or None).
    """
    # prog must match the registered command name so that usage and error
    # messages read 'trainremote', not 'train'.
    parser = argparse.ArgumentParser(prog='trainremote',
                                     usage='%(prog)s [options]')
    # run() always archives the tubs, so fail early with a proper usage
    # error instead of a TypeError on list(None) later on.
    parser.add_argument('--tub', nargs='+', required=True,
                        help='tub data for training')
    parser.add_argument('--path', default=None,
                        help='path of the car folder (contains myconfig.py, '
                             'receives the trained model)')
    parser.add_argument('--url', default=None,
                        help='url of the remote server to submit job')
    parser.add_argument('--get', default=None,
                        help='url of the remote server to get statuses')
    parser.add_argument('--model', default=None, help='output model name')

    return parser.parse_args(args)

def generate_tub_archive(self, tub_paths, carapp_path):
    """
    Pack the tubs and the car's myconfig.py into a temporary tar.gz archive.

    Every tub directory is stored under its own directory name, and
    ``<carapp_path>/myconfig.py`` is stored as ``myconfig.py`` at the
    archive root.

    :param tub_paths:   iterable of tub directory paths.
    :param carapp_path: path of the car folder containing myconfig.py.
    :return:            file name of the created archive. The file is not
                        deleted automatically; the caller owns it.
    :raises OSError:    if a tub or myconfig.py cannot be read. In that case
                        the partially written archive is removed.
    """
    import os
    import tarfile
    import tempfile
    from pathlib import Path

    print("generating tub archive")
    # delete=False because the archive is consumed after this method returns.
    tmp = tempfile.NamedTemporaryFile(mode='w+b', suffix='.tar.gz',
                                      delete=False)
    try:
        # Context managers exit in reverse order: the tar stream is
        # finalised first, then the underlying temp file is closed.
        with tmp, tarfile.open(fileobj=tmp, mode='w:gz') as tar:
            for tub_path in tub_paths:
                tub = Path(tub_path).expanduser()
                tar.add(tub, arcname=tub.name)
            config = Path(carapp_path).expanduser() / "myconfig.py"
            tar.add(config, arcname="myconfig.py")
    except Exception:
        # Don't leave a broken archive behind in the temp folder.
        os.unlink(tmp.name)
        raise

    return tmp.name

def get_wlan_mac_address(self, wlan):
    """
    Look up the link-layer (MAC) address of a network interface.

    :param wlan:    name of the network interface, e.g. 'wlan0'.
    :return:        the address string reported by netifaces, or the
                    string "None" (not the None object) when the interface
                    does not exist, has no link-layer entry, or does not
                    have exactly one such entry.
    """
    import netifaces

    not_found = "None"

    if wlan not in netifaces.interfaces():
        return not_found

    addresses = netifaces.ifaddresses(wlan)
    if addresses is None:
        return not_found

    if netifaces.AF_LINK not in addresses:
        return not_found

    link_entries = addresses[netifaces.AF_LINK]
    if len(link_entries) != 1:
        return not_found

    return link_entries[0]['addr']

def submit_job(self, tub_paths, carapp_path, submit_job_url):
    """
    Archive the tubs and submit them as a training job to the remote server.

    The archive is uploaded as a multipart form together with the MAC
    address of the WLAN interface, the host name and the donkeycar version.
    The temporary archive is removed again once the upload has finished or
    failed.

    :param tub_paths:       iterable of tub directory paths.
    :param carapp_path:     path of the car folder containing myconfig.py.
    :param submit_job_url:  url of the remote server to submit the job to.
    :return:                the 'job_uuid' value of the server's response.
    :raises Exception:      if the server does not answer with HTTP 200 or
                            the response carries no 'job_uuid'.
    """
    import logging
    import os
    import requests
    from requests_toolbelt.multipart.encoder import MultipartEncoder
    from donkeycar import __version__

    logger = logging.getLogger(__name__)
    failure = "Failed to call submit job"

    filename = self.generate_tub_archive(tub_paths, carapp_path)
    try:
        # Keep the archive open only for the duration of the upload.
        with open(filename, 'rb') as archive:
            mp_encoder = MultipartEncoder(
                fields={
                    'device_id': self.get_wlan_mac_address(self.WLAN),
                    'hostname': gethostname(),
                    'tub_archive_file': ('file.tar.gz', archive,
                                         'application/gzip'),
                    'donkeycar_version': str(__version__)
                }
            )
            logger.info("URL submitted to: %s", submit_job_url)
            logger.debug("Data to submit: %s", mp_encoder)
            r = requests.post(
                submit_job_url,
                # The MultipartEncoder is posted as data, don't use files=...!
                data=mp_encoder,
                # The MultipartEncoder provides the content-type header
                # with the boundary:
                headers={'Content-Type': mp_encoder.content_type}
            )
    finally:
        # The archive is only needed for the upload; tubs can be large.
        try:
            os.remove(filename)
        except OSError as e:
            logger.warning("Could not remove temporary archive %s: %s",
                           filename, e)

    logger.info("Submission response: HTTP %s", r.status_code)
    if r.status_code != 200:
        raise Exception(failure)

    try:
        # Parse the body once; a missing key or a non-JSON body both count
        # as a failed submission.
        uuid = r.json()['job_uuid']
    except Exception as e:
        logger.error(e)
        raise Exception(failure) from e

    logger.info("Submitted a Training Job to the remote server %s\n uuid: %s",
                submit_job_url, uuid)
    return uuid

def get_latest_job_status_from_hq(self, refresh_job_statuses, job_uuids):
    """
    Ask the remote server for the current status of submitted jobs.

    :param refresh_job_statuses:    url of the status endpoint.
    :param job_uuids:               job uuid(s), posted as form field
                                    'job_uuids'.
    :return:                        the decoded JSON body of the response.
    :raises Exception:              if the server does not answer with
                                    HTTP 200.
    :raises requests.exceptions.RequestException:
                                    on connection problems or when the
                                    server does not respond in time.
    """
    import logging
    import requests

    logger = logging.getLogger(__name__)
    logger.info("Getting latest job status for uuid %s", job_uuids)

    # Without a timeout a stalled connection would block forever and the
    # caller's overall training time limit could never trigger.
    response = requests.post(refresh_job_statuses,
                             data={"job_uuids": job_uuids},
                             timeout=60)
    if response.status_code != 200:
        logger.error("Status request failed: HTTP %s", response.status_code)
        logger.error(response.content)
        raise Exception("Problem requesting latest job status from hq")

    return response.json()

def run(self, args):
    """
    Submit a remote training job, wait for it and download the results.

    The job status is polled every 10 seconds. Once the job is reported as
    COMPLETED, the model (.h5) and the accuracy graph (.png) are downloaded
    into ``<car>/models`` and the movie (.mp4) into ``<car>/movies``.

    :param args:    list of command line argument strings.
    :raises Exception:
                    if the job fails on the server, reports an unexpected
                    status, or is not finished within TIMEOUT seconds.
    :raises subprocess.CalledProcessError:
                    if one of the downloads fails.
    """
    import logging
    import time

    logger = logging.getLogger(__name__)

    args = self.parse_args(args)
    tub_paths = list(args.tub)
    car_path = make_dir(args.path or './')
    url = args.url or self.SUBMIT_JOB_URL
    url_status = args.get or self.REFRESH_JOB_STATUS_URL

    uuid = self.submit_job(tub_paths, car_path, url)
    if not uuid:
        return

    if args.model:
        # get the filename without base path and extension
        model_name = os.path.splitext(os.path.basename(args.model))[0]
    else:
        model_name = f"job_{uuid}"

    pending = ("SCHEDULED", "TRAINING")
    failed = ("NO_CAPACITY", "NO_QUOTA", "SPOT_REQUEST_FAILED", "TIMEOUT")

    start_time = time.time()
    while True:
        result = self.get_latest_job_status_from_hq(url_status, uuid)[0]
        status = result['status']
        run_time = time.time() - start_time
        logger.info("Training Status: %s, time spent: %.0f s",
                    status, run_time)

        if status == "COMPLETED":
            break
        if status in failed:
            raise Exception(
                f"Failed to train the submitted job\nError : {status}")
        if status not in pending:
            raise Exception(f"Unexpected training status: {status}")
        if run_time >= self.TIMEOUT:
            raise Exception(
                f"Training did not finish within {self.TIMEOUT} s")
        # Only wait when another poll is actually going to happen.
        time.sleep(10)

    logger.info("Training completed!!")
    logger.info("Model URL: %s", result['model_url'])
    logger.info("Model Accuracy URL: %s", result['model_accuracy_url'])
    logger.info("Model Movie URL: %s", result['model_movie_url'])

    models_dir = os.path.join(car_path, "models")
    movies_dir = os.path.join(car_path, "movies")
    # curl does not create missing target folders.
    for folder in (models_dir, movies_dir):
        os.makedirs(folder, exist_ok=True)

    downloads = (
        (result['model_url'],
         os.path.join(models_dir, f"{model_name}.h5")),
        (result['model_accuracy_url'],
         os.path.join(models_dir, f"{model_name}.png")),
        (result['model_movie_url'],
         os.path.join(movies_dir, f"{model_name}.mp4")),
    )
    for source_url, target in downloads:
        # Wait for each download and fail loudly (check=True together with
        # curl --fail) instead of reporting success for a transfer that is
        # still running or has failed.
        subprocess.run(["curl", "--fail", source_url, "--output", target],
                       check=True)
    logger.info("Downloaded model, accuracy graph and movie from the URLs")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...


class Gui(BaseCommand):
def run(self, args):
Expand All @@ -486,6 +649,7 @@ def execute_from_command_line():
'cnnactivations': ShowCnnActivations,
'update': UpdateCar,
'train': Train,
'trainremote': TrainRemote,
'ui': Gui,
}

Expand Down
3 changes: 3 additions & 0 deletions install/envs/mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ dependencies:
- psutil
- kivy=2.0.0
- plotly
- requests
- requests_toolbelt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this package can't be found in the conda main or conda-forge channels; maybe it's only available via pip?

- netifaces
- pip:
- tensorflow==2.2.0
- git+https://github.com/autorope/keras-vis.git
Expand Down
3 changes: 3 additions & 0 deletions install/envs/ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ dependencies:
- kivy=2.0.0
- plotly
- tensorflow==2.2.0
- requests
- requests_toolbelt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to produce the same problem as in the Mac conda yaml file.

- netifaces
- pip:
- git+https://github.com/autorope/keras-vis.git
- simple-pid
Expand Down
3 changes: 3 additions & 0 deletions install/envs/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ dependencies:
- kivy=2.0.0
- plotly
- psutil
- requests
- requests_toolbelt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea if this works on Windows... Do you have any means to check that?

- netifaces
- pip:
- git+https://github.com/autorope/keras-vis.git
- simple-pid
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def package_files(directory, strip_leading):
'progress',
'typing_extensions',
'pyfiglet',
'psutil'
'psutil',
'requests_toolbelt',
'netifaces'
],
extras_require={
'pi': [
Expand Down