Skip to content

Commit cb61fdc

Browse files
committed
support static credentials
1 parent 83c3b3d commit cb61fdc

File tree

7 files changed

+241
-193
lines changed

7 files changed

+241
-193
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,7 @@
2929
"enum-compat>=0.0.1",
3030
),
3131
options={"bdist_wheel": {"universal": True}},
32+
extras_require={
33+
"yc": ["yandexcloud", ],
34+
}
3235
)

ydb/_utilities.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
interceptor = None
1414

1515

16+
_grpcs_protocol = "grpcs://"
17+
_grpc_protocol = "grpc://"
18+
19+
1620
def wrap_result_in_future(result):
1721
f = futures.Future()
1822
f.set_result(result)
@@ -33,6 +37,32 @@ def x_ydb_sdk_build_info_header():
3337
return ("x-ydb-sdk-build-info", "ydb-python-sdk/" + ydb_version.VERSION)
3438

3539

40+
def is_secure_protocol(endpoint):
41+
return endpoint.startswith("grpcs://")
42+
43+
44+
def wrap_endpoint(endpoint):
45+
if endpoint.startswith(_grpcs_protocol):
46+
return endpoint[len(_grpcs_protocol) :]
47+
if endpoint.startswith(_grpc_protocol):
48+
return endpoint[len(_grpc_protocol) :]
49+
return endpoint
50+
51+
52+
def parse_connection_string(connection_string):
53+
cs = connection_string
54+
if not cs.startswith(_grpc_protocol) and not cs.startswith(_grpcs_protocol):
55+
# default is grpcs
56+
cs = _grpcs_protocol + cs
57+
58+
p = six.moves.urllib.parse.urlparse(connection_string)
59+
b = six.moves.urllib.parse.parse_qs(p.query)
60+
database = b.get("database", [])
61+
assert len(database) > 0
62+
63+
return p.scheme + "://" + p.netloc, database[0]
64+
65+
3666
# Decorator that ensures no exceptions are leaked from decorated async call
3767
def wrap_async_call_exceptions(f):
3868
@functools.wraps(f)

ydb/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def _construct_metadata(driver_config, settings):
138138
if driver_config.database is not None:
139139
metadata.append((YDB_DATABASE_HEADER, driver_config.database))
140140

141-
if driver_config.credentials is not None:
141+
need_rpc_auth = getattr(settings, "need_rpc_auth", True)
142+
if driver_config.credentials is not None and need_rpc_auth:
142143
metadata.extend(driver_config.credentials.auth_metadata())
143144

144145
if settings is not None:

ydb/credentials.py

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
# -*- coding: utf-8 -*-
22
import abc
33
import six
4-
from . import tracing
4+
from . import tracing, issues, connection
5+
from . import settings as settings_impl
6+
import threading
7+
from concurrent import futures
8+
import logging
9+
import time
10+
from ydb.public.api.protos import ydb_auth_pb2
11+
from ydb.public.api.grpc import ydb_auth_v1_pb2_grpc
12+
513

614
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
15+
logger = logging.getLogger(__name__)
716

817

918
@six.add_metaclass(abc.ABCMeta)
@@ -26,6 +35,178 @@ def auth_metadata(self):
2635
pass
2736

2837

38+
class OneToManyValue(object):
39+
def __init__(self):
40+
self._value = None
41+
self._condition = threading.Condition()
42+
43+
def consume(self, timeout=3):
44+
with self._condition:
45+
if self._value is None:
46+
self._condition.wait(timeout=timeout)
47+
return self._value
48+
49+
def update(self, n_value):
50+
with self._condition:
51+
prev_value = self._value
52+
self._value = n_value
53+
if prev_value is None:
54+
self._condition.notify_all()
55+
56+
57+
class AtMostOneExecution(object):
58+
def __init__(self):
59+
self._can_schedule = True
60+
self._lock = threading.Lock()
61+
self._tp = futures.ThreadPoolExecutor(1)
62+
63+
def wrapped_execution(self, callback):
64+
try:
65+
callback()
66+
except Exception:
67+
pass
68+
69+
finally:
70+
self.cleanup()
71+
72+
def submit(self, callback):
73+
with self._lock:
74+
if self._can_schedule:
75+
self._tp.submit(self.wrapped_execution, callback)
76+
self._can_schedule = False
77+
78+
def cleanup(self):
79+
with self._lock:
80+
self._can_schedule = True
81+
82+
83+
@six.add_metaclass(abc.ABCMeta)
84+
class AbstractExpiringTokenCredentials(Credentials):
85+
def __init__(self, tracer=None):
86+
super(AbstractExpiringTokenCredentials, self).__init__(tracer)
87+
self._expires_in = 0
88+
self._refresh_in = 0
89+
self._hour = 60 * 60
90+
self._cached_token = OneToManyValue()
91+
self._tp = AtMostOneExecution()
92+
self.logger = logger.getChild(self.__class__.__name__)
93+
self.last_error = None
94+
self.extra_error_message = ""
95+
96+
@abc.abstractmethod
97+
def _make_token_request(self):
98+
pass
99+
100+
def _log_refresh_start(self, current_time):
101+
self.logger.debug("Start refresh token from metadata")
102+
if current_time > self._refresh_in:
103+
self.logger.info(
104+
"Cached token reached refresh_in deadline, current time %s, deadline %s",
105+
current_time,
106+
self._refresh_in,
107+
)
108+
109+
if current_time > self._expires_in and self._expires_in > 0:
110+
self.logger.error(
111+
"Cached token reached expires_in deadline, current time %s, deadline %s",
112+
current_time,
113+
self._expires_in,
114+
)
115+
116+
def _update_expiration_info(self, auth_metadata):
117+
self._expires_in = time.time() + min(
118+
self._hour, auth_metadata["expires_in"] / 2
119+
)
120+
self._refresh_in = time.time() + min(
121+
self._hour / 2, auth_metadata["expires_in"] / 4
122+
)
123+
124+
def _refresh(self):
125+
current_time = time.time()
126+
self._log_refresh_start(current_time)
127+
try:
128+
token_response = self._make_token_request()
129+
self._cached_token.update(token_response["access_token"])
130+
self._update_expiration_info(token_response)
131+
self.logger.info(
132+
"Token refresh successful. current_time %s, refresh_in %s",
133+
current_time,
134+
self._refresh_in,
135+
)
136+
137+
except (KeyboardInterrupt, SystemExit):
138+
return
139+
140+
except Exception as e:
141+
self.last_error = str(e)
142+
time.sleep(1)
143+
self._tp.submit(self._refresh)
144+
145+
@property
146+
@tracing.with_trace()
147+
def token(self):
148+
current_time = time.time()
149+
if current_time > self._refresh_in:
150+
tracing.trace(self.tracer, {"refresh": True})
151+
self._tp.submit(self._refresh)
152+
cached_token = self._cached_token.consume(timeout=15)
153+
tracing.trace(self.tracer, {"consumed": True})
154+
if cached_token is None:
155+
if self.last_error is None:
156+
raise issues.ConnectionError(
157+
"%s: timeout occurred while waiting for token.\n%s"
158+
% (
159+
self.__class__.__name__,
160+
self.extra_error_message,
161+
)
162+
)
163+
raise issues.ConnectionError(
164+
"%s: %s.\n%s"
165+
% (self.__class__.__name__, self.last_error, self.extra_error_message)
166+
)
167+
return cached_token
168+
169+
def auth_metadata(self):
170+
return [(YDB_AUTH_TICKET_HEADER, self.token)]
171+
172+
173+
def _wrap_static_credentials_response(rpc_state, response):
174+
issues._process_response(response.operation)
175+
result = ydb_auth_pb2.LoginResult()
176+
response.operation.result.Unpack(result)
177+
return result
178+
179+
180+
class StaticCredentials(AbstractExpiringTokenCredentials):
181+
def __init__(self, driver_config, user, password="", tracer=None):
182+
super(StaticCredentials, self).__init__(tracer)
183+
self.driver_config = driver_config
184+
self.user = user
185+
self.password = password
186+
self.request_timeout = 10
187+
188+
def _make_token_request(self):
189+
conn = connection.Connection.ready_factory(
190+
self.driver_config.endpoint, self.driver_config
191+
)
192+
assert conn is not None, (
193+
"Failed to establish connection in to %s" % self.driver_config.endpoint
194+
)
195+
try:
196+
result = conn(
197+
ydb_auth_pb2.LoginRequest(user=self.user, password=self.password),
198+
ydb_auth_v1_pb2_grpc.AuthServiceStub,
199+
"Login",
200+
_wrap_static_credentials_response,
201+
settings_impl.BaseRequestSettings()
202+
.with_timeout(self.request_timeout)
203+
.with_need_rpc_auth(False),
204+
)
205+
finally:
206+
conn.close()
207+
return {"expires_in": 30 * 60, "access_token": result.token}
208+
209+
29210
class AnonymousCredentials(Credentials):
30211
@staticmethod
31212
def auth_metadata():

ydb/driver.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,13 @@
44
import six
55
import os
66
import grpc
7+
from . import _utilities
78

89
if six.PY2:
910
Any = None
1011
else:
1112
from typing import Any # noqa
1213

13-
_grpcs_protocol = "grpcs://"
14-
_grpc_protocol = "grpc://"
15-
16-
17-
def is_secure_protocol(endpoint):
18-
return endpoint.startswith("grpcs://")
19-
20-
21-
def wrap_endpoint(endpoint):
22-
if endpoint.startswith(_grpcs_protocol):
23-
return endpoint[len(_grpcs_protocol) :]
24-
if endpoint.startswith(_grpc_protocol):
25-
return endpoint[len(_grpc_protocol) :]
26-
return endpoint
27-
28-
29-
def parse_connection_string(connection_string):
30-
cs = connection_string
31-
if not cs.startswith(_grpc_protocol) and not cs.startswith(_grpcs_protocol):
32-
# default is grpcs
33-
cs = _grpcs_protocol + cs
34-
35-
p = six.moves.urllib.parse.urlparse(connection_string)
36-
b = six.moves.urllib.parse.parse_qs(p.query)
37-
database = b.get("database", [])
38-
assert len(database) > 0
39-
40-
return p.scheme + "://" + p.netloc, database[0]
41-
4214

4315
class RPCCompression:
4416
"""Indicates the compression method to be used for an RPC."""
@@ -152,11 +124,11 @@ def __init__(
152124
self.database = database
153125
self.ca_cert = ca_cert
154126
self.channel_options = channel_options
155-
self.secure_channel = is_secure_protocol(endpoint)
156-
self.endpoint = wrap_endpoint(self.endpoint)
127+
self.secure_channel = _utilities.is_secure_protocol(endpoint)
128+
self.endpoint = _utilities.wrap_endpoint(self.endpoint)
157129
self.endpoints = []
158130
if endpoints is not None:
159-
self.endpoints = [wrap_endpoint(endp) for endp in endpoints]
131+
self.endpoints = [_utilities.wrap_endpoint(endp) for endp in endpoints]
160132
if auth_token is not None:
161133
credentials = credentials_impl.AuthTokenCredentials(auth_token)
162134
self.credentials = credentials
@@ -192,7 +164,7 @@ def default_from_endpoint_and_database(
192164
def default_from_connection_string(
193165
cls, connection_string, root_certificates=None, credentials=None, **kwargs
194166
):
195-
endpoint, database = parse_connection_string(connection_string)
167+
endpoint, database = _utilities.parse_connection_string(connection_string)
196168
return cls(
197169
endpoint,
198170
database,

0 commit comments

Comments
 (0)