81 changes: 52 additions & 29 deletions toolchest_client/api/query.py
@@ -12,6 +12,8 @@
import threading
import time

import boto3
from botocore.exceptions import ClientError
import requests
from requests.exceptions import HTTPError

@@ -164,11 +166,18 @@ def _add_input_file(self, input_file_path, input_prefix, input_is_in_s3=False):
try:
response.raise_for_status()
except HTTPError:
print(f"Failed to upload file at {input_file_path}", file=sys.stderr)
print(f"Failed to add input file at {input_file_path}", file=sys.stderr)
raise

if not input_is_in_s3:
return response.json().get("input_file_upload_location")
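# The response provides temporary S3 credentials and the destination for the upload.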
response_json = response.json()
return {
"access_key_id": response_json.get('access_key_id'),
"secret_access_key": response_json.get('secret_access_key'),
"session_token": response_json.get('session_token'),
"bucket": response_json.get('bucket'),
"object_name": response_json.get('object_name'),
}

def _upload(self, input_file_paths, input_prefix_mapping):
"""Uploads the files at ``input_file_paths`` to Toolchest."""
@@ -188,22 +197,28 @@ def _upload(self, input_file_paths, input_prefix_mapping):
)
else:
print(f"Uploading {file_path}")
upload_url = self._add_input_file(
input_file_keys = self._add_input_file(
input_file_path=file_path,
input_prefix=input_prefix_mapping.get(file_path),
input_is_in_s3=input_is_in_s3,
)

upload_response = requests.put(
upload_url,
data=open(file_path, "rb")
)
try:
upload_response.raise_for_status()
except HTTPError as e:
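# Upload the input file directly to S3 with the per-file credentials returned by _add_input_file.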
s3_client = boto3.client(
's3',
aws_access_key_id=input_file_keys["access_key_id"],
aws_secret_access_key=input_file_keys["secret_access_key"],
aws_session_token=input_file_keys["session_token"],
)
s3_client.upload_file(
file_path,
input_file_keys["bucket"],
input_file_keys["object_name"],
)
except ClientError as e:
# todo: this isn't propagating as a failure
self._update_status_to_failed(
f"Input file upload failed for file at {file_path}.",
f"{e} \n\nInput file upload failed for file at {file_path}.",
force_raise=True
)

@@ -309,31 +324,34 @@ def _get_job_status(self):
def _download(self, output_path):
"""Downloads output to ``output_path``."""

download_signed_url = self._get_download_signed_url()
output_file_keys = self._get_download()

self._update_status(Status.TRANSFERRING_TO_CLIENT)

# Downloads output by sending a GET request.
with requests.get(download_signed_url, stream=True) as r:
# Validates response of GET request.
try:
r.raise_for_status()
except HTTPError:
self._update_status_to_failed(
"Output download failed.",
force_raise=True,
)
try:
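# Download the output directly from S3 with the temporary credentials returned by _get_download.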
s3_client = boto3.client(
's3',
aws_access_key_id=output_file_keys["access_key_id"],
aws_secret_access_key=output_file_keys["secret_access_key"],
aws_session_token=output_file_keys["session_token"],
)
s3_client.download_file(
output_file_keys["bucket"],
output_file_keys["object_name"],
output_path,
)

# Writes streamed output data from response to the output file.
with open(output_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
except ClientError as e:
# TODO: output more detailed error message if write error encountered
self._update_status_to_failed(
f"{e} \n\nOutput download failed.",
force_raise=True
)

self._update_status(Status.TRANSFERRED_TO_CLIENT)
self.mark_as_failed = False

def _get_download_signed_url(self):
def _get_download(self):
"""Gets URL for downloading output of query task(s)."""

response = requests.get(
@@ -348,9 +366,14 @@ def _get_download_signed_url(self):
force_raise=True,
)

# TODO: add support for multiple download files

return response.json()[0]["signed_url"]
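# The response provides temporary S3 credentials and the location of the output object.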
response_json = response.json()[0]  # assumes only one output file
return {
"access_key_id": response_json.get('access_key_id'),
"secret_access_key": response_json.get('secret_access_key'),
"session_token": response_json.get('session_token'),
"bucket": response_json.get('bucket'),
"object_name": response_json.get('object_name'),
}

def _unpack_output(self, output_path, output_type):
"""After downloading, unpack files if needed"""