Skip to content

Commit

Permalink
simplify sse handling (#949)
Browse files Browse the repository at this point in the history
  • Loading branch information
balamurugana authored Aug 12, 2020
1 parent d4f0bde commit ae653ac
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 56 deletions.
62 changes: 23 additions & 39 deletions minio/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@
from .fold_case_dict import FoldCaseDict
from .helpers import (DEFAULT_PART_SIZE, MAX_MULTIPART_COUNT, MAX_PART_SIZE,
MAX_POOL_SIZE, MIN_PART_SIZE, amzprefix_user_metadata,
check_bucket_name, check_non_empty_string, dump_http,
get_md5_base64digest, get_s3_region_from_endpoint,
get_scheme_host, get_sha256_hexdigest, get_target_url,
is_amz_header, is_supported_header, is_valid_endpoint,
check_bucket_name, check_non_empty_string, check_sse,
check_ssec, dump_http, get_md5_base64digest,
get_s3_region_from_endpoint, get_scheme_host,
get_sha256_hexdigest, get_target_url, is_amz_header,
is_supported_header, is_valid_endpoint,
is_valid_notification_config, is_valid_policy_type,
is_valid_sse_c_object, is_valid_sse_object, mkdir_p,
optimal_part_info, read_full)
mkdir_p, optimal_part_info, read_full)
from .parsers import (parse_assume_role, parse_copy_object,
parse_get_bucket_notification, parse_list_buckets,
parse_list_multipart_uploads, parse_list_object_versions,
Expand All @@ -70,7 +70,6 @@
from .signer import (_SIGN_V4_ALGORITHM, _UNSIGNED_PAYLOAD,
generate_credential_string, post_presign_signature,
presign_v4, sign_v4)
from .sse import SseCustomerKey
from .thread_pool import ThreadPool
from .xml_marshal import (marshal_bucket_notifications,
marshal_complete_multipart,
Expand Down Expand Up @@ -880,17 +879,15 @@ def get_object(self, bucket_name, object_name, offset=0, length=0,
"""
check_bucket_name(bucket_name)
check_non_empty_string(object_name)
is_valid_sse_c_object(sse)
check_ssec(sse)

headers = request_headers or {}
headers = sse.headers() if sse else {}
headers.update(request_headers or {})

if offset or length:
headers['Range'] = 'bytes={}-{}'.format(
offset, offset + length - 1 if length else "")

if sse:
headers.update(sse.headers())

if version_id:
extra_query_params = extra_query_params or {}
extra_query_params["versionId"] = version_id
Expand Down Expand Up @@ -949,15 +946,13 @@ def copy_object(self, bucket_name, object_name, object_source,
if conditions:
headers.update(conditions)

# Source argument to copy_object can only be of type SSE-C
if source_sse:
is_valid_sse_c_object(source_sse)
headers.update(source_sse.copy_headers())
# Source sse must be SSE-C if not null.
check_ssec(source_sse)
headers.update(source_sse.copy_headers() if source_sse else {})

# Destination argument to copy_object cannot be of type SSE-C
if sse:
is_valid_sse_object(sse)
headers.update(sse.headers())
# Destination sse can be any Sse type if not null.
check_sse(sse)
headers.update(sse.headers() if sse else {})

headers['X-Amz-Copy-Source'] = quote(object_source)

Expand Down Expand Up @@ -993,7 +988,7 @@ def put_object(self, bucket_name, object_name, data, length,
'foo', 'bar', data, file_stat.st_size, 'text/plain',
)
"""
is_valid_sse_object(sse)
check_sse(sse)
check_bucket_name(bucket_name)
check_non_empty_string(object_name)

Expand Down Expand Up @@ -1164,11 +1159,9 @@ def stat_object(self, bucket_name, object_name, sse=None, version_id=None,

check_bucket_name(bucket_name)
check_non_empty_string(object_name)
check_ssec(sse)

headers = {}
if sse:
is_valid_sse_c_object(sse)
headers = sse.headers()
headers = sse.headers() if sse else {}

if version_id:
if extra_query_params:
Expand Down Expand Up @@ -1730,21 +1723,15 @@ def _do_put_object(self, bucket_name, object_name, part_data,
if md5_base64:
headers['Content-Md5'] = md5_base64

if metadata:
headers.update(metadata)
headers.update(metadata or {})
headers.update(sse.headers() if sse else {})

query = {}
if part_number > 0 and upload_id:
if part_number and upload_id:
query = {
'uploadId': upload_id,
'partNumber': str(part_number),
}
# Encryption headers for multipart uploads should
# be set only in the case of SSE-C.
if sse and isinstance(sse, SseCustomerKey):
headers.update(sse.headers())
elif sse:
headers.update(sse.headers())

response = self._url_open(
'PUT',
Expand Down Expand Up @@ -1905,11 +1892,8 @@ def _new_multipart_upload(self, bucket_name, object_name,
check_bucket_name(bucket_name)
check_non_empty_string(object_name)

headers = {}
if metadata:
headers.update(metadata)
if sse:
headers.update(sse.headers())
headers = metadata or {}
headers.update(sse.headers() if sse else {})

response = self._url_open('POST', bucket_name=bucket_name,
object_name=object_name,
Expand Down
22 changes: 6 additions & 16 deletions minio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,26 +550,16 @@ def _check_value(value, key):
return True


def is_valid_sse_c_object(sse):
"""
Validate the SSE object and type
:param sse: SSE object defined.
"""
def check_ssec(sse):
"""Check sse is SseCustomerKey type or not."""
if sse and not isinstance(sse, SseCustomerKey):
raise InvalidArgumentError(
"Required type SSE-C object to be passed")
raise InvalidArgumentError("SseCustomerKey type is required")


def is_valid_sse_object(sse):
"""
Validate the SSE object and type
:param sse: SSE object defined.
"""
def check_sse(sse):
"""Check sse is Sse type or not."""
if sse and not isinstance(sse, Sse):
raise InvalidArgumentError(
"unsuported type of sse argument in put_object")
raise InvalidArgumentError("Sse type is required")


def encode_object_name(object_name):
Expand Down
9 changes: 8 additions & 1 deletion minio/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ class Sse:
def headers(self):
"""Return headers."""

def tls_required(self): # pylint: disable=no-self-use
"""Return TLS required to use this server-side encryption."""
return True

def copy_headers(self): # pylint: disable=no-self-use
"""Return copy headers."""
raise TypeError("method unsupported")
return {}


class SseCustomerKey(Sse):
Expand Down Expand Up @@ -100,3 +104,6 @@ def headers(self):
return {
"X-Amz-Server-Side-Encryption": "AES256"
}

def tls_required(self):
return False

0 comments on commit ae653ac

Please sign in to comment.