|
1 | 1 | import requests
|
2 | 2 | import json
|
| 3 | +import base64 |
3 | 4 | from collections import namedtuple
|
4 | 5 | from requests.auth import HTTPBasicAuth
|
5 | 6 | from requests_toolbelt.utils import dump
|
| 7 | +from datetime import datetime, timezone |
6 | 8 |
|
7 | 9 |
|
8 | 10 | class DomoAPITransport:
|
@@ -71,23 +73,52 @@ def request(self, url, method, headers, params=None, body=None):
|
71 | 73 | 'params': params, 'data': body, 'stream': True}
|
72 | 74 | if self.request_timeout:
|
73 | 75 | request_args['timeout'] = self.request_timeout
|
74 |
| - response = requests.request(**request_args) |
75 |
| - if response.status_code == requests.codes.UNAUTHORIZED: |
| 76 | + |
| 77 | + # Expiration date should be in UTC |
| 78 | + if datetime.now(timezone.utc).timestamp() > self.token_expiration: |
| 79 | + self.logger.debug("Access token is expired") |
76 | 80 | self._renew_access_token()
|
77 | 81 | headers['Authorization'] = 'bearer ' + self.access_token
|
78 |
| - response = requests.request(**request_args) |
79 |
| - return response |
| 82 | + |
| 83 | + return requests.request(**request_args) |
80 | 84 |
|
81 | 85 | def _renew_access_token(self):
|
82 | 86 | self.logger.debug("Renewing Access Token")
|
83 | 87 | url = self.apiHost + '/oauth/token?grant_type=client_credentials'
|
84 | 88 | response = requests.post(url=url, auth=HTTPBasicAuth(self.clientId, self.clientSecret))
|
85 | 89 | if response.status_code == requests.codes.OK:
|
86 | 90 | self.access_token = response.json()['access_token']
|
| 91 | + self.token_expiration = self._extract_expiration(self.access_token) |
87 | 92 | else:
|
88 | 93 | self.logger.debug('Error retrieving access token: ' + self.dump_response(response))
|
89 | 94 | raise Exception("Error retrieving a Domo API Access Token: " + response.text)
|
90 | 95 |
|
| 96 | + def _extract_expiration(self, access_token): |
| 97 | + expiration_date = 0 |
| 98 | + try: |
| 99 | + decoded_payload_dict = self._decode_payload(access_token) |
| 100 | + |
| 101 | + if 'exp' in decoded_payload_dict.keys(): |
| 102 | + expiration_date = decoded_payload_dict['exp'] |
| 103 | + self.logger.debug('Token expiration: {}' |
| 104 | + .format(expiration_date)) |
| 105 | + except Exception as err: |
| 106 | + # If an Exception is raised, log and continue. expiration_date will |
| 107 | + # either be 0 or set to the value in the JWT. |
| 108 | + self.logger.debug('Ran into error parsing token for expiration. ' |
| 109 | + 'Setting expiration date to 0. ' |
| 110 | + '{}: {}'.format(type(err).__name__, err)) |
| 111 | + return expiration_date |
| 112 | + |
| 113 | + def _decode_payload(self, access_token): |
| 114 | + token_parts = access_token.split('.') |
| 115 | + |
| 116 | + # Padding required for the base64 library |
| 117 | + payload_bytes = bytes(token_parts[1], 'utf-8') + b'==' |
| 118 | + decoded_payload_bytes = base64.urlsafe_b64decode(payload_bytes) |
| 119 | + payload_string = decoded_payload_bytes.decode('utf-8') |
| 120 | + return json.loads(payload_string) |
| 121 | + |
91 | 122 | def dump_response(self, response):
|
92 | 123 | data = dump.dump_all(response)
|
93 | 124 | return str(data.decode('utf-8'))
|
|
0 commit comments