Skip to content

Commit 184df95

Browse files
authored
Provide a UrlCredentialsProvider for assume role (#317)
1 parent 8ccf49c commit 184df95

File tree

2 files changed

+67
-5
lines changed

2 files changed

+67
-5
lines changed

tosfs/certification.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
"""It contains everything about certification via a file-based provider."""
1515

1616
import threading
17-
from datetime import datetime
17+
from datetime import datetime, timedelta
1818
from typing import Optional
1919
from xml.etree import ElementTree
2020

21+
import requests
22+
from tos.consts import ECS_DATE_FORMAT
2123
from tos.credential import Credentials, CredentialsProvider
2224

2325
from tosfs.core import logger
@@ -182,3 +184,55 @@ def _try_get_credentials(self) -> Optional[Credentials]:
182184
)
183185
return None
184186
return self.credentials
187+
188+
189+
class UrlCredentialsProvider(CredentialsProvider):
190+
"""The class provides the credentials from an url."""
191+
192+
def __init__(self, credential_url: str):
193+
"""Initialize the UrlCredentialsProvider."""
194+
if not credential_url:
195+
raise TosfsCertificationError("The credential_url param must not be empty.")
196+
self._lock = threading.Lock()
197+
self.expires: Optional[datetime] = None
198+
self.credentials = None
199+
self.credential_url = credential_url
200+
201+
def get_credentials(self) -> Credentials:
202+
"""Get the credentials from the url."""
203+
res = self._try_get_credentials()
204+
if res is not None:
205+
return res
206+
with self._lock:
207+
try:
208+
res = self._try_get_credentials()
209+
if res is not None:
210+
return res
211+
212+
res = requests.get(self.credential_url, timeout=30)
213+
res_body = res.json()
214+
self.credentials = Credentials(
215+
res_body.get("AccessKeyId"),
216+
res_body.get("SecretAccessKey"),
217+
res_body.get("SessionToken"),
218+
)
219+
self.expires = datetime.strptime(
220+
res_body.get("ExpiredTime"), ECS_DATE_FORMAT
221+
)
222+
return self.credentials
223+
except Exception as e:
224+
if self.expires is not None and (
225+
datetime.now().timestamp() < self.expires.timestamp()
226+
):
227+
return self.credentials
228+
raise TosfsCertificationError("Get token failed") from e
229+
230+
def _try_get_credentials(self) -> Optional[Credentials]:
231+
if self.expires is None or self.credentials is None:
232+
return None
233+
if (
234+
datetime.now().timestamp()
235+
> (self.expires - timedelta(minutes=10)).timestamp()
236+
):
237+
return None
238+
return self.credentials

tosfs/exceptions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,27 @@
1313
# limitations under the License.
1414

1515
"""It contains exceptions definition for the tosfs package."""
16+
from typing import Optional
1617

1718

1819
class TosfsError(Exception):
1920
"""Base class for all tosfs exceptions."""
2021

21-
def __init__(self, message: str):
22+
def __init__(self, msg: str, cause: Optional[Exception] = None):
2223
"""Initialize the base class for all exceptions in the tosfs package."""
23-
super().__init__(message)
24+
super().__init__(msg, cause)
25+
self.message = msg
26+
self.cause = cause
27+
28+
def __str__(self) -> str:
29+
"""Return the string representation of the exception."""
30+
error = {"message": self.message, "case": str(self.cause)}
31+
return str(error)
2432

2533

2634
class TosfsCertificationError(TosfsError):
2735
"""Exception class for certification related exception."""
2836

29-
def __init__(self, message: str):
37+
def __init__(self, message: str, cause: Optional[Exception] = None):
3038
"""Initialize the exception class for certification related exception."""
31-
super().__init__(message)
39+
super().__init__(message, cause)

0 commit comments

Comments
 (0)