Skip to content

Commit 792d8d7

Browse files
authored
Merge pull request #19 from ydb-platform/support-node-id
allow to refer endpoints by node id
2 parents 862505a + 8cb1d72 commit 792d8d7

File tree

8 files changed

+75
-20
lines changed

8 files changed

+75
-20
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 2.2.0 ##
2+
3+
* allow to refer endpoints by node id
4+
15
## 2.1.0 ##
26

37
* add compression support to ydb sdk

ydb/_session_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def initialize_session(rpc_state, response_pb, session_state, session):
237237
issues._process_response(response_pb.operation)
238238
message = _apis.ydb_table.CreateSessionResult()
239239
response_pb.operation.result.Unpack(message)
240-
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint)
240+
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint_key)
241241
return session
242242

243243

ydb/aio/connection.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
YDB_DATABASE_HEADER,
1919
YDB_TRACE_ID_HEADER,
2020
YDB_REQUEST_TYPE_HEADER,
21+
EndpointKey,
2122
)
2223
from ydb.driver import DriverConfig
2324
from ydb.settings import BaseRequestSettings
@@ -71,8 +72,8 @@ class _RpcState(RpcState):
7172
"_trailing_metadata",
7273
)
7374

74-
def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str):
75-
super().__init__(stub_instance, rpc_name, endpoint)
75+
def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str, endpoint_key):
76+
super().__init__(stub_instance, rpc_name, endpoint, endpoint_key)
7677

7778
async def __call__(self, *args, **kwargs):
7879
resp = self.rpc(*args, **kwargs)
@@ -105,6 +106,8 @@ class Connection:
105106
"lock",
106107
"calls",
107108
"closing",
109+
"endpoint_key",
110+
"node_id",
108111
)
109112

110113
def __init__(
@@ -115,6 +118,10 @@ def __init__(
115118
):
116119
global _stubs_list
117120
self.endpoint = endpoint
121+
self.endpoint_key = EndpointKey(
122+
self.endpoint, getattr(endpoint_options, "node_id", None)
123+
)
124+
self.node_id = getattr(endpoint_options, "node_id", None)
118125
self._channel = channel_factory(
119126
self.endpoint, driver_config, grpc.aio, endpoint_options=endpoint_options
120127
)
@@ -141,7 +148,9 @@ async def _prepare_call(
141148
)
142149
_set_server_timeouts(request, settings, timeout)
143150
self._prepare_stub_instance(stub)
144-
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
151+
rpc_state = _RpcState(
152+
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
153+
)
145154
logger.debug("%s: creating call state", rpc_state)
146155

147156
if self.closing:

ydb/aio/pool.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,16 @@ async def get(self, preferred_endpoint=None, fast_fail=False, wait_timeout=10):
3030
else:
3131
await asyncio.wait_for(self._event.wait(), timeout=wait_timeout)
3232

33-
if preferred_endpoint is not None and preferred_endpoint in self.connections:
33+
if (
34+
preferred_endpoint is not None
35+
and preferred_endpoint.node_id in self.connections_by_node_id
36+
):
37+
return self.connections_by_node_id[preferred_endpoint.node_id]
38+
39+
if (
40+
preferred_endpoint is not None
41+
and preferred_endpoint.endpoint in self.connections
42+
):
3443
return self.connections[preferred_endpoint]
3544

3645
for conn_lst in self.conn_lst_order:
@@ -52,6 +61,8 @@ def add(self, connection, preferred=False):
5261

5362
if preferred:
5463
self.preferred[connection.endpoint] = connection
64+
65+
self.connections_by_node_id[connection.node_id] = connection
5566
self.connections[connection.endpoint] = connection
5667

5768
self._event.set()
@@ -66,6 +77,7 @@ def complete_discovery(self, error):
6677
self._fast_fail_event.set()
6778

6879
def remove(self, connection):
80+
self.connections_by_node_id.pop(connection.node_id, None)
6981
self.preferred.pop(connection.endpoint, None)
7082
self.connections.pop(connection.endpoint, None)
7183
self.outdated.pop(connection.endpoint, None)

ydb/connection.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,11 @@ def _get_request_timeout(settings):
163163

164164

165165
class EndpointOptions(object):
166-
__slots__ = ("ssl_target_name_override",)
166+
__slots__ = ("ssl_target_name_override", "node_id")
167167

168-
def __init__(self, ssl_target_name_override=None):
168+
def __init__(self, ssl_target_name_override=None, node_id=None):
169169
self.ssl_target_name_override = ssl_target_name_override
170+
self.node_id = node_id
170171

171172

172173
def _construct_channel_options(driver_config, endpoint_options=None):
@@ -223,16 +224,18 @@ class _RpcState(object):
223224
"endpoint",
224225
"rendezvous",
225226
"metadata_kv",
227+
"endpoint_key",
226228
)
227229

228-
def __init__(self, stub_instance, rpc_name, endpoint):
230+
def __init__(self, stub_instance, rpc_name, endpoint, endpoint_key):
229231
"""Stores all RPC related data"""
230232
self.rpc_name = rpc_name
231233
self.rpc = getattr(stub_instance, rpc_name)
232234
self.request_id = uuid.uuid4()
233235
self.endpoint = endpoint
234236
self.rendezvous = None
235237
self.metadata_kv = None
238+
self.endpoint_key = endpoint_key
236239

237240
def __str__(self):
238241
return "RpcState(%s, %s, %s)" % (self.rpc_name, self.request_id, self.endpoint)
@@ -318,6 +321,14 @@ def channel_factory(
318321
)
319322

320323

324+
class EndpointKey(object):
325+
__slots__ = ("endpoint", "node_id")
326+
327+
def __init__(self, endpoint, node_id):
328+
self.endpoint = endpoint
329+
self.node_id = node_id
330+
331+
321332
class Connection(object):
322333
__slots__ = (
323334
"endpoint",
@@ -330,6 +341,8 @@ class Connection(object):
330341
"lock",
331342
"calls",
332343
"closing",
344+
"endpoint_key",
345+
"node_id",
333346
)
334347

335348
def __init__(self, endpoint, driver_config=None, endpoint_options=None):
@@ -341,6 +354,10 @@ def __init__(self, endpoint, driver_config=None, endpoint_options=None):
341354
"""
342355
global _stubs_list
343356
self.endpoint = endpoint
357+
self.node_id = getattr(endpoint_options, "node_id", None)
358+
self.endpoint_key = EndpointKey(
359+
endpoint, getattr(endpoint_options, "node_id", None)
360+
)
344361
self._channel = channel_factory(
345362
self.endpoint, driver_config, endpoint_options=endpoint_options
346363
)
@@ -368,7 +385,9 @@ def _prepare_call(self, stub, rpc_name, request, settings):
368385
)
369386
_set_server_timeouts(request, settings, timeout)
370387
self._prepare_stub_instance(stub)
371-
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
388+
rpc_state = _RpcState(
389+
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
390+
)
372391
logger.debug("%s: creating call state", rpc_state)
373392
with self.lock:
374393
if self.closing:

ydb/pool.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, use_all_nodes=False, tracer=tracing.Tracer(None)):
1919
self.tracer = tracer
2020
self.lock = threading.RLock()
2121
self.connections = collections.OrderedDict()
22+
self.connections_by_node_id = collections.OrderedDict()
2223
self.outdated = collections.OrderedDict()
2324
self.subscriptions = set()
2425
self.preferred = collections.OrderedDict()
@@ -39,6 +40,8 @@ def add(self, connection, preferred=False):
3940
with self.lock:
4041
if preferred:
4142
self.preferred[connection.endpoint] = connection
43+
44+
self.connections_by_node_id[connection.node_id] = connection
4245
self.connections[connection.endpoint] = connection
4346
subscriptions = list(self.subscriptions)
4447
self.subscriptions.clear()
@@ -128,9 +131,14 @@ def get(self, preferred_endpoint=None):
128131
with self.lock:
129132
if (
130133
preferred_endpoint is not None
131-
and preferred_endpoint in self.connections
134+
and preferred_endpoint.node_id in self.connections_by_node_id
135+
):
136+
return self.connections_by_node_id[preferred_endpoint.node_id]
137+
138+
if (
139+
preferred_endpoint is not None
140+
and preferred_endpoint.endpoint in self.connections
132141
):
133-
tracing.trace(self.tracer, {"found_preferred_endpoint": True})
134142
return self.connections[preferred_endpoint]
135143

136144
for conn_lst in self.conn_lst_order:
@@ -146,6 +154,7 @@ def get(self, preferred_endpoint=None):
146154

147155
def remove(self, connection):
148156
with self.lock:
157+
self.connections_by_node_id.pop(connection.node_id, None)
149158
self.preferred.pop(connection.endpoint, None)
150159
self.connections.pop(connection.endpoint, None)
151160
self.outdated.pop(connection.endpoint, None)

ydb/resolver.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class EndpointInfo(object):
1919
"ipv4_addrs",
2020
"ipv6_addrs",
2121
"ssl_target_name_override",
22+
"node_id",
2223
)
2324

2425
def __init__(self, endpoint_info):
@@ -30,19 +31,20 @@ def __init__(self, endpoint_info):
3031
self.ipv4_addrs = tuple(endpoint_info.ip_v4)
3132
self.ipv6_addrs = tuple(endpoint_info.ip_v6)
3233
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
34+
self.node_id = endpoint_info.node_id
3335

3436
def endpoints_with_options(self):
37+
ssl_target_name_override = None
3538
if self.ssl:
3639
if self.ssl_target_name_override:
37-
endpoint_options = conn_impl.EndpointOptions(
38-
self.ssl_target_name_override
39-
)
40+
ssl_target_name_override = self.ssl_target_name_override
4041
elif self.ipv6_addrs or self.ipv4_addrs:
41-
endpoint_options = conn_impl.EndpointOptions(self.address)
42-
else:
43-
endpoint_options = None
44-
else:
45-
endpoint_options = None
42+
ssl_target_name_override = self.address
43+
44+
endpoint_options = conn_impl.EndpointOptions(
45+
ssl_target_name_override=ssl_target_name_override, node_id=self.node_id
46+
)
47+
4648
if self.ipv6_addrs or self.ipv4_addrs:
4749
for ipv6addr in self.ipv6_addrs:
4850
yield ("ipv6:[%s]:%s" % (ipv6addr, self.port), endpoint_options)

ydb/ydb_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "2.1.0"
1+
VERSION = "2.2.0"

0 commit comments

Comments
 (0)