Skip to content

Commit aa78b91

Browse files
committed
refactoring to support static credentials
1 parent 83c3b3d commit aa78b91

File tree

2 files changed

+147
-159
lines changed

2 files changed

+147
-159
lines changed

ydb/credentials.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# -*- coding: utf-8 -*-
22
import abc
33
import six
4-
from . import tracing
4+
from . import tracing, issues
5+
import threading
6+
from concurrent import futures
7+
import logging
8+
import time
59

610
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
11+
logger = logging.getLogger(__name__)
712

813

914
@six.add_metaclass(abc.ABCMeta)
@@ -26,6 +31,139 @@ def auth_metadata(self):
2631
pass
2732

2833

34+
class OneToManyValue(object):
35+
def __init__(self):
36+
self._value = None
37+
self._condition = threading.Condition()
38+
39+
def consume(self, timeout=3):
40+
with self._condition:
41+
if self._value is None:
42+
self._condition.wait(timeout=timeout)
43+
return self._value
44+
45+
def update(self, n_value):
46+
with self._condition:
47+
prev_value = self._value
48+
self._value = n_value
49+
if prev_value is None:
50+
self._condition.notify_all()
51+
52+
53+
class AtMostOneExecution(object):
54+
def __init__(self):
55+
self._can_schedule = True
56+
self._lock = threading.Lock()
57+
self._tp = futures.ThreadPoolExecutor(1)
58+
59+
def wrapped_execution(self, callback):
60+
try:
61+
callback()
62+
except Exception:
63+
pass
64+
65+
finally:
66+
self.cleanup()
67+
68+
def submit(self, callback):
69+
with self._lock:
70+
if self._can_schedule:
71+
self._tp.submit(self.wrapped_execution, callback)
72+
self._can_schedule = False
73+
74+
def cleanup(self):
75+
with self._lock:
76+
self._can_schedule = True
77+
78+
79+
@six.add_metaclass(abc.ABCMeta)
80+
class AbstractExpiringTokenCredentials(Credentials):
81+
def __init__(self, tracer=None):
82+
super(AbstractExpiringTokenCredentials, self).__init__(tracer)
83+
self._expires_in = 0
84+
self._refresh_in = 0
85+
self._hour = 60 * 60
86+
self._cached_token = OneToManyValue()
87+
self._tp = AtMostOneExecution()
88+
self.logger = logger.getChild(self.__class__.__name__)
89+
self.last_error = None
90+
self.extra_error_message = ""
91+
92+
@abc.abstractmethod
93+
def _make_token_request(self):
94+
pass
95+
96+
def _log_refresh_start(self, current_time):
97+
self.logger.debug("Start refresh token from metadata")
98+
if current_time > self._refresh_in:
99+
self.logger.info(
100+
"Cached token reached refresh_in deadline, current time %s, deadline %s",
101+
current_time,
102+
self._refresh_in,
103+
)
104+
105+
if current_time > self._expires_in and self._expires_in > 0:
106+
self.logger.error(
107+
"Cached token reached expires_in deadline, current time %s, deadline %s",
108+
current_time,
109+
self._expires_in,
110+
)
111+
112+
def _update_expiration_info(self, auth_metadata):
113+
self._expires_in = time.time() + min(
114+
self._hour, auth_metadata["expires_in"] / 2
115+
)
116+
self._refresh_in = time.time() + min(
117+
self._hour / 2, auth_metadata["expires_in"] / 4
118+
)
119+
120+
def _refresh(self):
121+
current_time = time.time()
122+
self._log_refresh_start(current_time)
123+
try:
124+
token_response = self._make_token_request()
125+
self._cached_token.update(token_response["access_token"])
126+
self._update_expiration_info(token_response)
127+
self.logger.info(
128+
"Token refresh successful. current_time %s, refresh_in %s",
129+
current_time,
130+
self._refresh_in,
131+
)
132+
133+
except (KeyboardInterrupt, SystemExit):
134+
return
135+
136+
except Exception as e:
137+
self.last_error = str(e)
138+
time.sleep(1)
139+
self._tp.submit(self._refresh)
140+
141+
@property
142+
@tracing.with_trace()
143+
def token(self):
144+
current_time = time.time()
145+
if current_time > self._refresh_in:
146+
tracing.trace(self.tracer, {"refresh": True})
147+
self._tp.submit(self._refresh)
148+
cached_token = self._cached_token.consume(timeout=3)
149+
tracing.trace(self.tracer, {"consumed": True})
150+
if cached_token is None:
151+
if self.last_error is None:
152+
raise issues.ConnectionError(
153+
"%s: timeout occurred while waiting for token.\n%s"
154+
% self.__class__.__name__,
155+
self.extra_error_message,
156+
)
157+
raise issues.ConnectionError(
158+
"%s: %s.\n%s"
159+
% (self.__class__.__name__, self.last_error, self.extra_error_message)
160+
)
161+
return cached_token
162+
163+
def auth_metadata(self):
164+
return [(YDB_AUTH_TICKET_HEADER, self.token)]
165+
166+
29167
class AnonymousCredentials(Credentials):
30168
@staticmethod
31169
def auth_metadata():

ydb/iam/auth.py

Lines changed: 8 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
import six
77
from datetime import datetime
88
import json
9-
import threading
10-
from concurrent import futures
119
import os
12-
import logging
13-
from ydb import issues
14-
15-
logger = logging.getLogger(__name__)
1610

1711
try:
1812
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
@@ -51,158 +45,18 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout):
5145
)
5246

5347

54-
class OneToManyValue(object):
55-
def __init__(self):
56-
self._value = None
57-
self._condition = threading.Condition()
58-
59-
def consume(self, timeout=3):
60-
with self._condition:
61-
if self._value is None:
62-
self._condition.wait(timeout=timeout)
63-
return self._value
64-
65-
def update(self, n_value):
66-
with self._condition:
67-
prev_value = self._value
68-
self._value = n_value
69-
if prev_value is None:
70-
self._condition.notify_all()
71-
72-
73-
class AtMostOneExecution(object):
74-
def __init__(self):
75-
self._can_schedule = True
76-
self._lock = threading.Lock()
77-
self._tp = futures.ThreadPoolExecutor(1)
78-
79-
def wrapped_execution(self, callback):
80-
try:
81-
callback()
82-
except Exception:
83-
pass
84-
85-
finally:
86-
self.cleanup()
87-
88-
def submit(self, callback):
89-
with self._lock:
90-
if self._can_schedule:
91-
self._tp.submit(self.wrapped_execution, callback)
92-
self._can_schedule = False
93-
94-
def cleanup(self):
95-
with self._lock:
96-
self._can_schedule = True
97-
98-
9948
@six.add_metaclass(abc.ABCMeta)
100-
class IamTokenCredentials(credentials.Credentials):
101-
def __init__(self, tracer=None):
102-
super(IamTokenCredentials, self).__init__(tracer)
103-
self._expires_in = 0
104-
self._refresh_in = 0
105-
self._hour = 60 * 60
106-
self._iam_token = OneToManyValue()
107-
self._tp = AtMostOneExecution()
108-
self.logger = logger.getChild(self.__class__.__name__)
109-
self.last_error = None
110-
self.extra_error_message = ""
111-
112-
@abc.abstractmethod
113-
def _get_iam_token(self):
114-
pass
115-
116-
def _log_refresh_start(self, current_time):
117-
self.logger.debug("Start refresh token from metadata")
118-
if current_time > self._refresh_in:
119-
self.logger.info(
120-
"Cached token reached refresh_in deadline, current time %s, deadline %s",
121-
current_time,
122-
self._refresh_in,
123-
)
124-
125-
if current_time > self._expires_in and self._expires_in > 0:
126-
self.logger.error(
127-
"Cached token reached expires_in deadline, current time %s, deadline %s",
128-
current_time,
129-
self._expires_in,
130-
)
131-
132-
def _update_expiration_info(self, auth_metadata):
133-
self._expires_in = time.time() + min(
134-
self._hour, auth_metadata["expires_in"] / 2
135-
)
136-
self._refresh_in = time.time() + min(
137-
self._hour / 2, auth_metadata["expires_in"] / 4
138-
)
139-
140-
def _refresh(self):
141-
current_time = time.time()
142-
self._log_refresh_start(current_time)
143-
try:
144-
auth_metadata = self._get_iam_token()
145-
self._iam_token.update(auth_metadata["access_token"])
146-
self._update_expiration_info(auth_metadata)
147-
self.logger.info(
148-
"Token refresh successful. current_time %s, refresh_in %s",
149-
current_time,
150-
self._refresh_in,
151-
)
152-
153-
except (KeyboardInterrupt, SystemExit):
154-
return
155-
156-
except Exception as e:
157-
self.last_error = str(e)
158-
time.sleep(1)
159-
self._tp.submit(self._refresh)
160-
161-
@property
162-
@tracing.with_trace()
163-
def iam_token(self):
164-
current_time = time.time()
165-
if current_time > self._refresh_in:
166-
tracing.trace(self.tracer, {"refresh": True})
167-
self._tp.submit(self._refresh)
168-
iam_token = self._iam_token.consume(timeout=3)
169-
tracing.trace(self.tracer, {"consumed": True})
170-
if iam_token is None:
171-
if self.last_error is None:
172-
raise issues.ConnectionError(
173-
"%s: timeout occurred while waiting for token.\n%s"
174-
% self.__class__.__name__,
175-
self.extra_error_message,
176-
)
177-
raise issues.ConnectionError(
178-
"%s: %s.\n%s"
179-
% (self.__class__.__name__, self.last_error, self.extra_error_message)
180-
)
181-
return iam_token
182-
183-
def auth_metadata(self):
184-
return [(credentials.YDB_AUTH_TICKET_HEADER, self.iam_token)]
185-
186-
187-
@six.add_metaclass(abc.ABCMeta)
188-
class TokenServiceCredentials(IamTokenCredentials):
49+
class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials):
18950
def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None):
19051
super(TokenServiceCredentials, self).__init__(tracer)
52+
assert iam_token_service_pb2_grpc is not None, "run pip install==ydb[yc] to use service account credentials"
53+
self._get_token_request_timeout = 10
19154
self._iam_endpoint = (
19255
"iam.api.cloud.yandex.net:443" if iam_endpoint is None else iam_endpoint
19356
)
19457
self._iam_channel_credentials = (
19558
{} if iam_channel_credentials is None else iam_channel_credentials
19659
)
197-
self._get_token_request_timeout = 10
198-
if (
199-
iam_token_service_pb2_grpc is None
200-
or jwt is None
201-
or iam_token_service_pb2 is None
202-
):
203-
raise RuntimeError(
204-
"Install jwt & yandex python cloud library to use service account credentials provider"
205-
)
20660

20761
def _channel_factory(self):
20862
return grpc.secure_channel(
@@ -215,7 +69,7 @@ def _get_token_request(self):
21569
pass
21670

21771
@tracing.with_trace()
218-
def _get_iam_token(self):
72+
def _make_token_request(self):
21973
with self._channel_factory() as channel:
22074
tracing.trace(self.tracer, {"iam_token.from_service": True})
22175
stub = iam_token_service_pb2_grpc.IamTokenServiceStub(channel)
@@ -296,26 +150,22 @@ def _get_token_request(self):
296150
)
297151

298152

299-
class MetadataUrlCredentials(IamTokenCredentials):
153+
class MetadataUrlCredentials(credentials.AbstractExpiringTokenCredentials):
300154
def __init__(self, metadata_url=None, tracer=None):
301155
"""
302-
303156
:param metadata_url: Metadata url
304157
:param ydb.Tracer tracer: ydb tracer
305158
"""
306159
super(MetadataUrlCredentials, self).__init__(tracer)
307-
if requests is None:
308-
raise RuntimeError(
309-
"Install requests library to use metadata credentials provider"
310-
)
160+
assert requests is not None, "Install requests library to use metadata credentials provider"
161+
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."
311162
self._metadata_url = (
312163
DEFAULT_METADATA_URL if metadata_url is None else metadata_url
313164
)
314165
self._tp.submit(self._refresh)
315-
self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud."
316166

317167
@tracing.with_trace()
318-
def _get_iam_token(self):
168+
def _make_token_request(self):
319169
response = requests.get(
320170
self._metadata_url, headers={"Metadata-Flavor": "Google"}, timeout=3
321171
)

0 commit comments

Comments
 (0)