1
1
# -*- coding: utf-8 -*-
2
2
import abc
3
3
import six
4
- from . import tracing
4
+ from . import tracing , issues , connection
5
+ from . import settings as settings_impl
6
+ import threading
7
+ from concurrent import futures
8
+ import logging
9
+ import time
10
+ from ydb .public .api .protos import ydb_auth_pb2
11
+ from ydb .public .api .grpc import ydb_auth_v1_pb2_grpc
12
+
5
13
6
14
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
15
+ logger = logging .getLogger (__name__ )
7
16
8
17
9
18
@six .add_metaclass (abc .ABCMeta )
@@ -26,6 +35,178 @@ def auth_metadata(self):
26
35
pass
27
36
28
37
38
+ class OneToManyValue (object ):
39
+ def __init__ (self ):
40
+ self ._value = None
41
+ self ._condition = threading .Condition ()
42
+
43
+ def consume (self , timeout = 3 ):
44
+ with self ._condition :
45
+ if self ._value is None :
46
+ self ._condition .wait (timeout = timeout )
47
+ return self ._value
48
+
49
+ def update (self , n_value ):
50
+ with self ._condition :
51
+ prev_value = self ._value
52
+ self ._value = n_value
53
+ if prev_value is None :
54
+ self ._condition .notify_all ()
55
+
56
+
57
+ class AtMostOneExecution (object ):
58
+ def __init__ (self ):
59
+ self ._can_schedule = True
60
+ self ._lock = threading .Lock ()
61
+ self ._tp = futures .ThreadPoolExecutor (1 )
62
+
63
+ def wrapped_execution (self , callback ):
64
+ try :
65
+ callback ()
66
+ except Exception :
67
+ pass
68
+
69
+ finally :
70
+ self .cleanup ()
71
+
72
+ def submit (self , callback ):
73
+ with self ._lock :
74
+ if self ._can_schedule :
75
+ self ._tp .submit (self .wrapped_execution , callback )
76
+ self ._can_schedule = False
77
+
78
+ def cleanup (self ):
79
+ with self ._lock :
80
+ self ._can_schedule = True
81
+
82
+
83
+ @six .add_metaclass (abc .ABCMeta )
84
+ class AbstractExpiringTokenCredentials (Credentials ):
85
+ def __init__ (self , tracer = None ):
86
+ super (AbstractExpiringTokenCredentials , self ).__init__ (tracer )
87
+ self ._expires_in = 0
88
+ self ._refresh_in = 0
89
+ self ._hour = 60 * 60
90
+ self ._cached_token = OneToManyValue ()
91
+ self ._tp = AtMostOneExecution ()
92
+ self .logger = logger .getChild (self .__class__ .__name__ )
93
+ self .last_error = None
94
+ self .extra_error_message = ""
95
+
96
+ @abc .abstractmethod
97
+ def _make_token_request (self ):
98
+ pass
99
+
100
+ def _log_refresh_start (self , current_time ):
101
+ self .logger .debug ("Start refresh token from metadata" )
102
+ if current_time > self ._refresh_in :
103
+ self .logger .info (
104
+ "Cached token reached refresh_in deadline, current time %s, deadline %s" ,
105
+ current_time ,
106
+ self ._refresh_in ,
107
+ )
108
+
109
+ if current_time > self ._expires_in and self ._expires_in > 0 :
110
+ self .logger .error (
111
+ "Cached token reached expires_in deadline, current time %s, deadline %s" ,
112
+ current_time ,
113
+ self ._expires_in ,
114
+ )
115
+
116
+ def _update_expiration_info (self , auth_metadata ):
117
+ self ._expires_in = time .time () + min (
118
+ self ._hour , auth_metadata ["expires_in" ] / 2
119
+ )
120
+ self ._refresh_in = time .time () + min (
121
+ self ._hour / 2 , auth_metadata ["expires_in" ] / 4
122
+ )
123
+
124
+ def _refresh (self ):
125
+ current_time = time .time ()
126
+ self ._log_refresh_start (current_time )
127
+ try :
128
+ token_response = self ._make_token_request ()
129
+ self ._cached_token .update (token_response ["access_token" ])
130
+ self ._update_expiration_info (token_response )
131
+ self .logger .info (
132
+ "Token refresh successful. current_time %s, refresh_in %s" ,
133
+ current_time ,
134
+ self ._refresh_in ,
135
+ )
136
+
137
+ except (KeyboardInterrupt , SystemExit ):
138
+ return
139
+
140
+ except Exception as e :
141
+ self .last_error = str (e )
142
+ time .sleep (1 )
143
+ self ._tp .submit (self ._refresh )
144
+
145
+ @property
146
+ @tracing .with_trace ()
147
+ def token (self ):
148
+ current_time = time .time ()
149
+ if current_time > self ._refresh_in :
150
+ tracing .trace (self .tracer , {"refresh" : True })
151
+ self ._tp .submit (self ._refresh )
152
+ cached_token = self ._cached_token .consume (timeout = 15 )
153
+ tracing .trace (self .tracer , {"consumed" : True })
154
+ if cached_token is None :
155
+ if self .last_error is None :
156
+ raise issues .ConnectionError (
157
+ "%s: timeout occurred while waiting for token.\n %s"
158
+ % (
159
+ self .__class__ .__name__ ,
160
+ self .extra_error_message ,
161
+ )
162
+ )
163
+ raise issues .ConnectionError (
164
+ "%s: %s.\n %s"
165
+ % (self .__class__ .__name__ , self .last_error , self .extra_error_message )
166
+ )
167
+ return cached_token
168
+
169
+ def auth_metadata (self ):
170
+ return [(YDB_AUTH_TICKET_HEADER , self .token )]
171
+
172
+
173
+ def _wrap_static_credentials_response (rpc_state , response ):
174
+ issues ._process_response (response .operation )
175
+ result = ydb_auth_pb2 .LoginResult ()
176
+ response .operation .result .Unpack (result )
177
+ return result
178
+
179
+
180
+ class StaticCredentials (AbstractExpiringTokenCredentials ):
181
+ def __init__ (self , driver_config , user , password = "" , tracer = None ):
182
+ super (StaticCredentials , self ).__init__ (tracer )
183
+ self .driver_config = driver_config
184
+ self .user = user
185
+ self .password = password
186
+ self .request_timeout = 10
187
+
188
+ def _make_token_request (self ):
189
+ conn = connection .Connection .ready_factory (
190
+ self .driver_config .endpoint , self .driver_config
191
+ )
192
+ assert conn is not None , (
193
+ "Failed to establish connection in to %s" % self .driver_config .endpoint
194
+ )
195
+ try :
196
+ result = conn (
197
+ ydb_auth_pb2 .LoginRequest (user = self .user , password = self .password ),
198
+ ydb_auth_v1_pb2_grpc .AuthServiceStub ,
199
+ "Login" ,
200
+ _wrap_static_credentials_response ,
201
+ settings_impl .BaseRequestSettings ()
202
+ .with_timeout (self .request_timeout )
203
+ .with_need_rpc_auth (False ),
204
+ )
205
+ finally :
206
+ conn .close ()
207
+ return {"expires_in" : 30 * 60 , "access_token" : result .token }
208
+
209
+
29
210
class AnonymousCredentials (Credentials ):
30
211
@staticmethod
31
212
def auth_metadata ():
0 commit comments