@@ -2246,8 +2246,13 @@ def test__set_metadata_to_none(self):
2246
2246
def test__get_upload_arguments (self ):
2247
2247
name = "blob-name"
2248
2248
key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
2249
+ custom_headers = {
2250
+ "x-goog-custom-audit-foo" : "bar" ,
2251
+ "x-goog-custom-audit-user" : "baz" ,
2252
+ }
2249
2253
client = mock .Mock (_connection = _Connection )
2250
2254
client ._connection .user_agent = "testing 1.2.3"
2255
+ client ._extra_headers = custom_headers
2251
2256
blob = self ._make_one (name , bucket = None , encryption_key = key )
2252
2257
blob .content_disposition = "inline"
2253
2258
@@ -2271,6 +2276,7 @@ def test__get_upload_arguments(self):
2271
2276
"X-Goog-Encryption-Algorithm" : "AES256" ,
2272
2277
"X-Goog-Encryption-Key" : header_key_value ,
2273
2278
"X-Goog-Encryption-Key-Sha256" : header_key_hash_value ,
2279
+ ** custom_headers ,
2274
2280
}
2275
2281
self .assertEqual (
2276
2282
headers ["X-Goog-API-Client" ],
@@ -2325,6 +2331,7 @@ def _do_multipart_success(
2325
2331
2326
2332
client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2327
2333
client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2334
+ client ._extra_headers = {}
2328
2335
2329
2336
# Mock get_api_base_url_for_mtls function.
2330
2337
mtls_url = "https://foo.mtls"
@@ -2424,11 +2431,14 @@ def _do_multipart_success(
2424
2431
with patch .object (
2425
2432
_helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
2426
2433
):
2427
- headers = _get_default_headers (
2428
- client ._connection .user_agent ,
2429
- b'multipart/related; boundary="==0=="' ,
2430
- "application/xml" ,
2431
- )
2434
+ headers = {
2435
+ ** _get_default_headers (
2436
+ client ._connection .user_agent ,
2437
+ b'multipart/related; boundary="==0=="' ,
2438
+ "application/xml" ,
2439
+ ),
2440
+ ** client ._extra_headers ,
2441
+ }
2432
2442
client ._http .request .assert_called_once_with (
2433
2443
"POST" , upload_url , data = payload , headers = headers , timeout = expected_timeout
2434
2444
)
@@ -2520,6 +2530,19 @@ def test__do_multipart_upload_with_client(self, mock_get_boundary):
2520
2530
transport = self ._mock_transport (http .client .OK , {})
2521
2531
client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2522
2532
client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2533
+ client ._extra_headers = {}
2534
+ self ._do_multipart_success (mock_get_boundary , client = client )
2535
+
2536
+ @mock .patch ("google.resumable_media._upload.get_boundary" , return_value = b"==0==" )
2537
+ def test__do_multipart_upload_with_client_custom_headers (self , mock_get_boundary ):
2538
+ custom_headers = {
2539
+ "x-goog-custom-audit-foo" : "bar" ,
2540
+ "x-goog-custom-audit-user" : "baz" ,
2541
+ }
2542
+ transport = self ._mock_transport (http .client .OK , {})
2543
+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2544
+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2545
+ client ._extra_headers = custom_headers
2523
2546
self ._do_multipart_success (mock_get_boundary , client = client )
2524
2547
2525
2548
@mock .patch ("google.resumable_media._upload.get_boundary" , return_value = b"==0==" )
@@ -2597,6 +2620,7 @@ def _initiate_resumable_helper(
2597
2620
# Create some mock arguments and call the method under test.
2598
2621
client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2599
2622
client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2623
+ client ._extra_headers = {}
2600
2624
2601
2625
# Mock get_api_base_url_for_mtls function.
2602
2626
mtls_url = "https://foo.mtls"
@@ -2677,13 +2701,15 @@ def _initiate_resumable_helper(
2677
2701
_helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
2678
2702
):
2679
2703
if extra_headers is None :
2680
- self .assertEqual (
2681
- upload ._headers ,
2682
- _get_default_headers (client ._connection .user_agent , content_type ),
2683
- )
2704
+ expected_headers = {
2705
+ ** _get_default_headers (client ._connection .user_agent , content_type ),
2706
+ ** client ._extra_headers ,
2707
+ }
2708
+ self .assertEqual (upload ._headers , expected_headers )
2684
2709
else :
2685
2710
expected_headers = {
2686
2711
** _get_default_headers (client ._connection .user_agent , content_type ),
2712
+ ** client ._extra_headers ,
2687
2713
** extra_headers ,
2688
2714
}
2689
2715
self .assertEqual (upload ._headers , expected_headers )
@@ -2730,9 +2756,12 @@ def _initiate_resumable_helper(
2730
2756
with patch .object (
2731
2757
_helpers , "_get_invocation_id" , return_value = GCCL_INVOCATION_TEST_CONST
2732
2758
):
2733
- expected_headers = _get_default_headers (
2734
- client ._connection .user_agent , x_upload_content_type = content_type
2735
- )
2759
+ expected_headers = {
2760
+ ** _get_default_headers (
2761
+ client ._connection .user_agent , x_upload_content_type = content_type
2762
+ ),
2763
+ ** client ._extra_headers ,
2764
+ }
2736
2765
if size is not None :
2737
2766
expected_headers ["x-upload-content-length" ] = str (size )
2738
2767
if extra_headers is not None :
@@ -2824,6 +2853,21 @@ def test__initiate_resumable_upload_with_client(self):
2824
2853
2825
2854
client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2826
2855
client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2856
+ client ._extra_headers = {}
2857
+ self ._initiate_resumable_helper (client = client )
2858
+
2859
+ def test__initiate_resumable_upload_with_client_custom_headers (self ):
2860
+ custom_headers = {
2861
+ "x-goog-custom-audit-foo" : "bar" ,
2862
+ "x-goog-custom-audit-user" : "baz" ,
2863
+ }
2864
+ resumable_url = "http://test.invalid?upload_id=hey-you"
2865
+ response_headers = {"location" : resumable_url }
2866
+ transport = self ._mock_transport (http .client .OK , response_headers )
2867
+
2868
+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
2869
+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
2870
+ client ._extra_headers = custom_headers
2827
2871
self ._initiate_resumable_helper (client = client )
2828
2872
2829
2873
def _make_resumable_transport (
@@ -3000,6 +3044,7 @@ def _do_resumable_helper(
3000
3044
client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3001
3045
client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3002
3046
client ._connection .user_agent = USER_AGENT
3047
+ client ._extra_headers = {}
3003
3048
stream = io .BytesIO (data )
3004
3049
3005
3050
bucket = _Bucket (name = "yesterday" )
@@ -3612,26 +3657,32 @@ def _create_resumable_upload_session_helper(
3612
3657
if_metageneration_match = None ,
3613
3658
if_metageneration_not_match = None ,
3614
3659
retry = None ,
3660
+ client = None ,
3615
3661
):
3616
3662
bucket = _Bucket (name = "alex-trebek" )
3617
3663
blob = self ._make_one ("blob-name" , bucket = bucket )
3618
3664
chunk_size = 99 * blob ._CHUNK_SIZE_MULTIPLE
3619
3665
blob .chunk_size = chunk_size
3620
-
3621
- # Create mocks to be checked for doing transport.
3622
3666
resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3623
- response_headers = {"location" : resumable_url }
3624
- transport = self ._mock_transport (http .client .OK , response_headers )
3625
- if side_effect is not None :
3626
- transport .request .side_effect = side_effect
3627
-
3628
- # Create some mock arguments and call the method under test.
3629
3667
content_type = "text/plain"
3630
3668
size = 10000
3631
- client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3632
- client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3633
- client ._connection .user_agent = "testing 1.2.3"
3669
+ transport = None
3634
3670
3671
+ if not client :
3672
+ # Create mocks to be checked for doing transport.
3673
+ response_headers = {"location" : resumable_url }
3674
+ transport = self ._mock_transport (http .client .OK , response_headers )
3675
+
3676
+ # Create some mock arguments and call the method under test.
3677
+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3678
+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3679
+ client ._connection .user_agent = "testing 1.2.3"
3680
+ client ._extra_headers = {}
3681
+
3682
+ if transport is None :
3683
+ transport = client ._http
3684
+ if side_effect is not None :
3685
+ transport .request .side_effect = side_effect
3635
3686
if timeout is None :
3636
3687
expected_timeout = self ._get_default_timeout ()
3637
3688
timeout_kwarg = {}
@@ -3689,6 +3740,7 @@ def _create_resumable_upload_session_helper(
3689
3740
** _get_default_headers (
3690
3741
client ._connection .user_agent , x_upload_content_type = content_type
3691
3742
),
3743
+ ** client ._extra_headers ,
3692
3744
"x-upload-content-length" : str (size ),
3693
3745
"x-upload-content-type" : content_type ,
3694
3746
}
@@ -3750,6 +3802,28 @@ def test_create_resumable_upload_session_with_failure(self):
3750
3802
self .assertIn (message , exc_info .exception .message )
3751
3803
self .assertEqual (exc_info .exception .errors , [])
3752
3804
3805
+ def test_create_resumable_upload_session_with_client (self ):
3806
+ resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3807
+ response_headers = {"location" : resumable_url }
3808
+ transport = self ._mock_transport (http .client .OK , response_headers )
3809
+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3810
+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3811
+ client ._extra_headers = {}
3812
+ self ._create_resumable_upload_session_helper (client = client )
3813
+
3814
+ def test_create_resumable_upload_session_with_client_custom_headers (self ):
3815
+ custom_headers = {
3816
+ "x-goog-custom-audit-foo" : "bar" ,
3817
+ "x-goog-custom-audit-user" : "baz" ,
3818
+ }
3819
+ resumable_url = "http://test.invalid?upload_id=clean-up-everybody"
3820
+ response_headers = {"location" : resumable_url }
3821
+ transport = self ._mock_transport (http .client .OK , response_headers )
3822
+ client = mock .Mock (_http = transport , _connection = _Connection , spec = ["_http" ])
3823
+ client ._connection .API_BASE_URL = "https://storage.googleapis.com"
3824
+ client ._extra_headers = custom_headers
3825
+ self ._create_resumable_upload_session_helper (client = client )
3826
+
3753
3827
def test_get_iam_policy_defaults (self ):
3754
3828
from google .cloud .storage .iam import STORAGE_OWNER_ROLE
3755
3829
from google .cloud .storage .iam import STORAGE_EDITOR_ROLE
@@ -5815,6 +5889,34 @@ def test_open(self):
5815
5889
with self .assertRaises (ValueError ):
5816
5890
blob .open ("w" , ignore_flush = False )
5817
5891
5892
+ def test_downloads_w_client_custom_headers (self ):
5893
+ import google .auth .credentials
5894
+ from google .cloud .storage import Client
5895
+
5896
+ custom_headers = {
5897
+ "x-goog-custom-audit-foo" : "bar" ,
5898
+ "x-goog-custom-audit-user" : "baz" ,
5899
+ }
5900
+ credentials = mock .Mock (spec = google .auth .credentials .Credentials )
5901
+ client = Client (
5902
+ project = "project" , credentials = credentials , extra_headers = custom_headers
5903
+ )
5904
+ blob = self ._make_one ("blob-name" , bucket = _Bucket (client ))
5905
+ file_obj = io .BytesIO ()
5906
+
5907
+ downloads = {
5908
+ client .download_blob_to_file : (blob , file_obj ),
5909
+ blob .download_to_file : (file_obj ,),
5910
+ blob .download_as_bytes : (),
5911
+ }
5912
+ for method , args in downloads .items ():
5913
+ with mock .patch .object (blob , "_do_download" ):
5914
+ method (* args )
5915
+ blob ._do_download .assert_called ()
5916
+ called_headers = blob ._do_download .call_args .args [- 4 ]
5917
+ self .assertIsInstance (called_headers , dict )
5918
+ self .assertDictContainsSubset (custom_headers , called_headers )
5919
+
5818
5920
5819
5921
class Test__quote (unittest .TestCase ):
5820
5922
@staticmethod
0 commit comments