Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

When refreshing credentials with an stream body, rewind the stream be… #174

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,11 @@ def new_request(uri, method='GET', body=None, headers=None,
else:
headers['user-agent'] = self.user_agent

body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
body_stream_position = body.tell()

resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)

Expand All @@ -567,6 +572,9 @@ def new_request(uri, method='GET', body=None, headers=None,
refresh_attempt + 1, max_refresh_attempts)
self._refresh(request_orig)
self.apply(headers)
if body_stream_position is not None:
body.seek(body_stream_position)

resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)

Expand Down
9 changes: 4 additions & 5 deletions tests/http_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,16 @@ def request(self, uri,
connection_type=None):
resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers':
content = headers
elif content == 'echo_request_headers_as_json':
content = json.dumps(headers)
elif content == 'echo_request_body':
if hasattr(body, 'read'):
content = body.read()
else:
content = body
content = body if body_stream_content is None else body_stream_content
elif content == 'echo_request_uri':
content = uri
elif not isinstance(content, bytes):
raise TypeError("http content should be bytes: %r" % (content,))
raise TypeError('http content should be bytes: %r' % (content,))
return httplib2.Response(resp), content
43 changes: 38 additions & 5 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@
import unittest

from .http_mock import HttpMockSequence
import six

from oauth2client import file
from oauth2client import locked_file
from oauth2client import multistore_file
from oauth2client import util
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials
from six.moves import http_client
try:
# Python2
from future_builtins import oct
Expand Down Expand Up @@ -154,15 +157,17 @@ def test_token_refresh_store_expires_soon(self):
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '401'}, b'Initial token expired'),
({'status': '401'}, b'Store token expired'),
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
({'status': '200'}, b'Valid response to original request')
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)},
b'Valid response to original request')
])

credentials.authorize(http)
http.request('https://example.com')
self.assertEquals(credentials.access_token, access_token)
self.assertEqual(credentials.access_token, access_token)

def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
Expand All @@ -178,6 +183,34 @@ def test_token_refresh_good_store(self):
credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar')

def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)

valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body')
])

body = six.StringIO('streaming body')

credentials.authorize(http)
_, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token)

def test_credentials_delete(self):
credentials = self.create_test_credentials()

Expand Down