Skip to content

Commit 2f92c3a

Browse files
cojencoMiaCY
andauthored
feat: add support for custom headers (#1121)
* Chore: refactor client.download_blob_to_file (#1052) * Refactor client.download_blob_to_file * Chore: clean up code * refactor blob and client unit tests * lint reformat * Rename _prep_and_do_download * Chore: refactor blob.upload_from_file (#1063) * Refactor client.download_blob_to_file * Chore: clean up code * refactor blob and client unit tests * lint reformat * Rename _prep_and_do_download * Refactor blob.upload_from_file * Lint reformat * feature: add 'command' argument to private upload/download interface (#1082) * Refactor client.download_blob_to_file * Chore: clean up code * refactor blob and client unit tests * lint reformat * Rename _prep_and_do_download * Refactor blob.upload_from_file * Lint reformat * feature: add 'command' argument to private upload/download interface * lint reformat * reduce duplication and edit docstring * feat: add support for custom headers starting with metadata op * add custom headers to downloads in client blob modules * add custom headers to uploads with tests * update mocks and tests * test custom headers support tm mpu uploads * update tm test * update test --------- Co-authored-by: MiaCY <97990237+MiaCY@users.noreply.github.com>
1 parent 1ef0e1a commit 2f92c3a

File tree

7 files changed

+219
-27
lines changed

7 files changed

+219
-27
lines changed

google/cloud/storage/blob.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,11 +1738,13 @@ def _get_upload_arguments(self, client, content_type, filename=None, command=Non
17381738
* The ``content_type`` as a string (according to precedence)
17391739
"""
17401740
content_type = self._get_content_type(content_type, filename=filename)
1741+
# Add any client attached custom headers to the upload headers.
17411742
headers = {
17421743
**_get_default_headers(
17431744
client._connection.user_agent, content_type, command=command
17441745
),
17451746
**_get_encryption_headers(self._encryption_key),
1747+
**client._extra_headers,
17461748
}
17471749
object_metadata = self._get_writable_metadata()
17481750
return headers, object_metadata, content_type
@@ -4313,9 +4315,11 @@ def _prep_and_do_download(
43134315
if_etag_match=if_etag_match,
43144316
if_etag_not_match=if_etag_not_match,
43154317
)
4318+
# Add any client attached custom headers to be sent with the request.
43164319
headers = {
43174320
**_get_default_headers(client._connection.user_agent, command=command),
43184321
**headers,
4322+
**client._extra_headers,
43194323
}
43204324

43214325
transport = client._http

google/cloud/storage/client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ class Client(ClientWithProject):
9494
(Optional) Whether authentication is required under custom endpoints.
9595
If false, uses AnonymousCredentials and bypasses authentication.
9696
Defaults to True. Note this is only used when a custom endpoint is set in conjunction.
97+
98+
:type extra_headers: dict
99+
:param extra_headers:
100+
(Optional) Custom headers to be sent with the requests attached to the client.
101+
For example, you can add custom audit logging headers.
97102
"""
98103

99104
SCOPE = (
@@ -111,6 +116,7 @@ def __init__(
111116
client_info=None,
112117
client_options=None,
113118
use_auth_w_custom_endpoint=True,
119+
extra_headers={},
114120
):
115121
self._base_connection = None
116122

@@ -127,6 +133,7 @@ def __init__(
127133
# are passed along, for use in __reduce__ defined elsewhere.
128134
self._initial_client_info = client_info
129135
self._initial_client_options = client_options
136+
self._extra_headers = extra_headers
130137

131138
kw_args = {"client_info": client_info}
132139

@@ -172,7 +179,10 @@ def __init__(
172179
if no_project:
173180
self.project = None
174181

175-
self._connection = Connection(self, **kw_args)
182+
# Pass extra_headers to Connection
183+
connection = Connection(self, **kw_args)
184+
connection.extra_headers = extra_headers
185+
self._connection = connection
176186
self._batch_stack = _LocalStack()
177187

178188
@classmethod

google/cloud/storage/transfer_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,7 @@ def _reduce_client(cl):
12891289
_http = None # Can't carry this over
12901290
client_info = cl._initial_client_info
12911291
client_options = cl._initial_client_options
1292+
extra_headers = cl._extra_headers
12921293

12931294
return _LazyClient, (
12941295
client_object_id,
@@ -1297,6 +1298,7 @@ def _reduce_client(cl):
12971298
_http,
12981299
client_info,
12991300
client_options,
1301+
extra_headers,
13001302
)
13011303

13021304

tests/unit/test__http.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,55 @@ def test_extra_headers(self):
7171
timeout=_DEFAULT_TIMEOUT,
7272
)
7373

74+
def test_metadata_op_has_client_custom_headers(self):
75+
import requests
76+
import google.auth.credentials
77+
from google.cloud import _http as base_http
78+
from google.cloud.storage import Client
79+
from google.cloud.storage.constants import _DEFAULT_TIMEOUT
80+
81+
custom_headers = {
82+
"x-goog-custom-audit-foo": "bar",
83+
"x-goog-custom-audit-user": "baz",
84+
}
85+
http = mock.create_autospec(requests.Session, instance=True)
86+
response = requests.Response()
87+
response.status_code = 200
88+
data = b"brent-spiner"
89+
response._content = data
90+
http.is_mtls = False
91+
http.request.return_value = response
92+
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
93+
client = Client(
94+
project="project",
95+
credentials=credentials,
96+
_http=http,
97+
extra_headers=custom_headers,
98+
)
99+
req_data = "hey-yoooouuuuu-guuuuuyyssss"
100+
with patch.object(
101+
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
102+
):
103+
result = client._connection.api_request(
104+
"GET", "/rainbow", data=req_data, expect_json=False
105+
)
106+
self.assertEqual(result, data)
107+
108+
expected_headers = {
109+
**custom_headers,
110+
"Accept-Encoding": "gzip",
111+
base_http.CLIENT_INFO_HEADER: f"{client._connection.user_agent} {GCCL_INVOCATION_TEST_CONST}",
112+
"User-Agent": client._connection.user_agent,
113+
}
114+
expected_uri = client._connection.build_api_url("/rainbow")
115+
http.request.assert_called_once_with(
116+
data=req_data,
117+
headers=expected_headers,
118+
method="GET",
119+
url=expected_uri,
120+
timeout=_DEFAULT_TIMEOUT,
121+
)
122+
74123
def test_build_api_url_no_extra_query_params(self):
75124
from urllib.parse import parse_qsl
76125
from urllib.parse import urlsplit

tests/unit/test_blob.py

Lines changed: 125 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,8 +2246,13 @@ def test__set_metadata_to_none(self):
22462246
def test__get_upload_arguments(self):
22472247
name = "blob-name"
22482248
key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
2249+
custom_headers = {
2250+
"x-goog-custom-audit-foo": "bar",
2251+
"x-goog-custom-audit-user": "baz",
2252+
}
22492253
client = mock.Mock(_connection=_Connection)
22502254
client._connection.user_agent = "testing 1.2.3"
2255+
client._extra_headers = custom_headers
22512256
blob = self._make_one(name, bucket=None, encryption_key=key)
22522257
blob.content_disposition = "inline"
22532258

@@ -2271,6 +2276,7 @@ def test__get_upload_arguments(self):
22712276
"X-Goog-Encryption-Algorithm": "AES256",
22722277
"X-Goog-Encryption-Key": header_key_value,
22732278
"X-Goog-Encryption-Key-Sha256": header_key_hash_value,
2279+
**custom_headers,
22742280
}
22752281
self.assertEqual(
22762282
headers["X-Goog-API-Client"],
@@ -2325,6 +2331,7 @@ def _do_multipart_success(
23252331

23262332
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
23272333
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2334+
client._extra_headers = {}
23282335

23292336
# Mock get_api_base_url_for_mtls function.
23302337
mtls_url = "https://foo.mtls"
@@ -2424,11 +2431,14 @@ def _do_multipart_success(
24242431
with patch.object(
24252432
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
24262433
):
2427-
headers = _get_default_headers(
2428-
client._connection.user_agent,
2429-
b'multipart/related; boundary="==0=="',
2430-
"application/xml",
2431-
)
2434+
headers = {
2435+
**_get_default_headers(
2436+
client._connection.user_agent,
2437+
b'multipart/related; boundary="==0=="',
2438+
"application/xml",
2439+
),
2440+
**client._extra_headers,
2441+
}
24322442
client._http.request.assert_called_once_with(
24332443
"POST", upload_url, data=payload, headers=headers, timeout=expected_timeout
24342444
)
@@ -2520,6 +2530,19 @@ def test__do_multipart_upload_with_client(self, mock_get_boundary):
25202530
transport = self._mock_transport(http.client.OK, {})
25212531
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
25222532
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2533+
client._extra_headers = {}
2534+
self._do_multipart_success(mock_get_boundary, client=client)
2535+
2536+
@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
2537+
def test__do_multipart_upload_with_client_custom_headers(self, mock_get_boundary):
2538+
custom_headers = {
2539+
"x-goog-custom-audit-foo": "bar",
2540+
"x-goog-custom-audit-user": "baz",
2541+
}
2542+
transport = self._mock_transport(http.client.OK, {})
2543+
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
2544+
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2545+
client._extra_headers = custom_headers
25232546
self._do_multipart_success(mock_get_boundary, client=client)
25242547

25252548
@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
@@ -2597,6 +2620,7 @@ def _initiate_resumable_helper(
25972620
# Create some mock arguments and call the method under test.
25982621
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
25992622
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2623+
client._extra_headers = {}
26002624

26012625
# Mock get_api_base_url_for_mtls function.
26022626
mtls_url = "https://foo.mtls"
@@ -2677,13 +2701,15 @@ def _initiate_resumable_helper(
26772701
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
26782702
):
26792703
if extra_headers is None:
2680-
self.assertEqual(
2681-
upload._headers,
2682-
_get_default_headers(client._connection.user_agent, content_type),
2683-
)
2704+
expected_headers = {
2705+
**_get_default_headers(client._connection.user_agent, content_type),
2706+
**client._extra_headers,
2707+
}
2708+
self.assertEqual(upload._headers, expected_headers)
26842709
else:
26852710
expected_headers = {
26862711
**_get_default_headers(client._connection.user_agent, content_type),
2712+
**client._extra_headers,
26872713
**extra_headers,
26882714
}
26892715
self.assertEqual(upload._headers, expected_headers)
@@ -2730,9 +2756,12 @@ def _initiate_resumable_helper(
27302756
with patch.object(
27312757
_helpers, "_get_invocation_id", return_value=GCCL_INVOCATION_TEST_CONST
27322758
):
2733-
expected_headers = _get_default_headers(
2734-
client._connection.user_agent, x_upload_content_type=content_type
2735-
)
2759+
expected_headers = {
2760+
**_get_default_headers(
2761+
client._connection.user_agent, x_upload_content_type=content_type
2762+
),
2763+
**client._extra_headers,
2764+
}
27362765
if size is not None:
27372766
expected_headers["x-upload-content-length"] = str(size)
27382767
if extra_headers is not None:
@@ -2824,6 +2853,21 @@ def test__initiate_resumable_upload_with_client(self):
28242853

28252854
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
28262855
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2856+
client._extra_headers = {}
2857+
self._initiate_resumable_helper(client=client)
2858+
2859+
def test__initiate_resumable_upload_with_client_custom_headers(self):
2860+
custom_headers = {
2861+
"x-goog-custom-audit-foo": "bar",
2862+
"x-goog-custom-audit-user": "baz",
2863+
}
2864+
resumable_url = "http://test.invalid?upload_id=hey-you"
2865+
response_headers = {"location": resumable_url}
2866+
transport = self._mock_transport(http.client.OK, response_headers)
2867+
2868+
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
2869+
client._connection.API_BASE_URL = "https://storage.googleapis.com"
2870+
client._extra_headers = custom_headers
28272871
self._initiate_resumable_helper(client=client)
28282872

28292873
def _make_resumable_transport(
@@ -3000,6 +3044,7 @@ def _do_resumable_helper(
30003044
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
30013045
client._connection.API_BASE_URL = "https://storage.googleapis.com"
30023046
client._connection.user_agent = USER_AGENT
3047+
client._extra_headers = {}
30033048
stream = io.BytesIO(data)
30043049

30053050
bucket = _Bucket(name="yesterday")
@@ -3612,26 +3657,32 @@ def _create_resumable_upload_session_helper(
36123657
if_metageneration_match=None,
36133658
if_metageneration_not_match=None,
36143659
retry=None,
3660+
client=None,
36153661
):
36163662
bucket = _Bucket(name="alex-trebek")
36173663
blob = self._make_one("blob-name", bucket=bucket)
36183664
chunk_size = 99 * blob._CHUNK_SIZE_MULTIPLE
36193665
blob.chunk_size = chunk_size
3620-
3621-
# Create mocks to be checked for doing transport.
36223666
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3623-
response_headers = {"location": resumable_url}
3624-
transport = self._mock_transport(http.client.OK, response_headers)
3625-
if side_effect is not None:
3626-
transport.request.side_effect = side_effect
3627-
3628-
# Create some mock arguments and call the method under test.
36293667
content_type = "text/plain"
36303668
size = 10000
3631-
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
3632-
client._connection.API_BASE_URL = "https://storage.googleapis.com"
3633-
client._connection.user_agent = "testing 1.2.3"
3669+
transport = None
36343670

3671+
if not client:
3672+
# Create mocks to be checked for doing transport.
3673+
response_headers = {"location": resumable_url}
3674+
transport = self._mock_transport(http.client.OK, response_headers)
3675+
3676+
# Create some mock arguments and call the method under test.
3677+
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
3678+
client._connection.API_BASE_URL = "https://storage.googleapis.com"
3679+
client._connection.user_agent = "testing 1.2.3"
3680+
client._extra_headers = {}
3681+
3682+
if transport is None:
3683+
transport = client._http
3684+
if side_effect is not None:
3685+
transport.request.side_effect = side_effect
36353686
if timeout is None:
36363687
expected_timeout = self._get_default_timeout()
36373688
timeout_kwarg = {}
@@ -3689,6 +3740,7 @@ def _create_resumable_upload_session_helper(
36893740
**_get_default_headers(
36903741
client._connection.user_agent, x_upload_content_type=content_type
36913742
),
3743+
**client._extra_headers,
36923744
"x-upload-content-length": str(size),
36933745
"x-upload-content-type": content_type,
36943746
}
@@ -3750,6 +3802,28 @@ def test_create_resumable_upload_session_with_failure(self):
37503802
self.assertIn(message, exc_info.exception.message)
37513803
self.assertEqual(exc_info.exception.errors, [])
37523804

3805+
def test_create_resumable_upload_session_with_client(self):
3806+
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3807+
response_headers = {"location": resumable_url}
3808+
transport = self._mock_transport(http.client.OK, response_headers)
3809+
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
3810+
client._connection.API_BASE_URL = "https://storage.googleapis.com"
3811+
client._extra_headers = {}
3812+
self._create_resumable_upload_session_helper(client=client)
3813+
3814+
def test_create_resumable_upload_session_with_client_custom_headers(self):
3815+
custom_headers = {
3816+
"x-goog-custom-audit-foo": "bar",
3817+
"x-goog-custom-audit-user": "baz",
3818+
}
3819+
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3820+
response_headers = {"location": resumable_url}
3821+
transport = self._mock_transport(http.client.OK, response_headers)
3822+
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
3823+
client._connection.API_BASE_URL = "https://storage.googleapis.com"
3824+
client._extra_headers = custom_headers
3825+
self._create_resumable_upload_session_helper(client=client)
3826+
37533827
def test_get_iam_policy_defaults(self):
37543828
from google.cloud.storage.iam import STORAGE_OWNER_ROLE
37553829
from google.cloud.storage.iam import STORAGE_EDITOR_ROLE
@@ -5815,6 +5889,34 @@ def test_open(self):
58155889
with self.assertRaises(ValueError):
58165890
blob.open("w", ignore_flush=False)
58175891

5892+
def test_downloads_w_client_custom_headers(self):
5893+
import google.auth.credentials
5894+
from google.cloud.storage import Client
5895+
5896+
custom_headers = {
5897+
"x-goog-custom-audit-foo": "bar",
5898+
"x-goog-custom-audit-user": "baz",
5899+
}
5900+
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
5901+
client = Client(
5902+
project="project", credentials=credentials, extra_headers=custom_headers
5903+
)
5904+
blob = self._make_one("blob-name", bucket=_Bucket(client))
5905+
file_obj = io.BytesIO()
5906+
5907+
downloads = {
5908+
client.download_blob_to_file: (blob, file_obj),
5909+
blob.download_to_file: (file_obj,),
5910+
blob.download_as_bytes: (),
5911+
}
5912+
for method, args in downloads.items():
5913+
with mock.patch.object(blob, "_do_download"):
5914+
method(*args)
5915+
blob._do_download.assert_called()
5916+
called_headers = blob._do_download.call_args.args[-4]
5917+
self.assertIsInstance(called_headers, dict)
5918+
self.assertDictContainsSubset(custom_headers, called_headers)
5919+
58185920

58195921
class Test__quote(unittest.TestCase):
58205922
@staticmethod

0 commit comments

Comments
 (0)