Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit c7e12e7

Browse files
committed
Use smart stub in clients
1 parent 81fa7db commit c7e12e7

File tree

3 files changed

+32
-51
lines changed

3 files changed

+32
-51
lines changed

packages/jumpstarter/jumpstarter/client/client.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
import grpc
77
from anyio.from_thread import BlockingPortal
88
from google.protobuf import empty_pb2
9-
from jumpstarter_protocol import jumpstarter_pb2_grpc
109

11-
from .grpc import SmartExporterServiceStub
10+
from .grpc import SmartExporterStub
1211
from jumpstarter.client import DriverClient
1312
from jumpstarter.common.importlib import import_class
1413

@@ -34,30 +33,24 @@ async def client_from_channel(
3433
reports = {}
3534
clients = OrderedDict()
3635

37-
response = await SmartExporterServiceStub([channel]).GetReport(empty_pb2.Empty())
36+
response = await SmartExporterStub([channel]).GetReport(empty_pb2.Empty())
3837

38+
channels = [channel]
3939
if use_alternative_endpoints:
4040
for endpoint in response.alternative_endpoints:
4141
if endpoint.certificate:
42-
attempted_channel = grpc.aio.secure_channel(
43-
endpoint.endpoint,
44-
grpc.ssl_channel_credentials(
45-
root_certificates=endpoint.certificate.encode(),
46-
private_key=endpoint.client_private_key.encode(),
47-
certificate_chain=endpoint.client_certificate.encode(),
48-
),
49-
)
50-
try:
51-
response = await jumpstarter_pb2_grpc.ExporterServiceStub(attempted_channel).GetReport(
52-
empty_pb2.Empty()
42+
channels.append(
43+
grpc.aio.secure_channel(
44+
endpoint.endpoint,
45+
grpc.ssl_channel_credentials(
46+
root_certificates=endpoint.certificate.encode(),
47+
private_key=endpoint.client_private_key.encode(),
48+
certificate_chain=endpoint.client_certificate.encode(),
49+
),
5350
)
54-
except Exception:
55-
pass # TODO: log failed attempt
56-
else:
57-
channel = attempted_channel
58-
break
59-
else:
60-
continue
51+
)
52+
53+
stub = SmartExporterStub(list(reversed(channels)))
6154

6255
for index, report in enumerate(response.reports):
6356
topo[index] = []
@@ -77,7 +70,7 @@ async def client_from_channel(
7770
client = client_class(
7871
uuid=UUID(report.uuid),
7972
labels=report.labels,
80-
channel=channel,
73+
stub=stub,
8174
portal=portal,
8275
stack=stack.enter_context(ExitStack()),
8376
children={reports[k].labels["jumpstarter.dev/name"]: clients[k] for k in topo[index]},

packages/jumpstarter/jumpstarter/client/core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import logging
66
from contextlib import asynccontextmanager
77
from dataclasses import dataclass, field
8+
from typing import Any
89

910
from anyio import create_task_group
1011
from google.protobuf import empty_pb2
1112
from grpc import StatusCode
12-
from grpc.aio import AioRpcError, Channel
13+
from grpc.aio import AioRpcError
1314
from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
1415

1516
from jumpstarter.common import Metadata
@@ -60,16 +61,14 @@ class AsyncDriverClient(
6061
Backing implementation of blocking driver client.
6162
"""
6263

63-
channel: Channel
64+
stub: Any
6465

6566
log_level: str = "INFO"
6667
logger: logging.Logger = field(init=False)
6768

6869
def __post_init__(self):
6970
if hasattr(super(), "__post_init__"):
7071
super().__post_init__()
71-
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel)
72-
router_pb2_grpc.RouterServiceStub.__init__(self, self.channel)
7372
self.logger = logging.getLogger(self.__class__.__name__)
7473
self.logger.setLevel(self.log_level)
7574

@@ -89,7 +88,7 @@ async def call_async(self, method, *args):
8988
)
9089

9190
try:
92-
response = await self.DriverCall(request)
91+
response = await self.stub.DriverCall(request)
9392
except AioRpcError as e:
9493
match e.code():
9594
case StatusCode.UNIMPLEMENTED:
@@ -113,7 +112,7 @@ async def streamingcall_async(self, method, *args):
113112
)
114113

115114
try:
116-
async for response in self.StreamingDriverCall(request):
115+
async for response in self.stub.StreamingDriverCall(request):
117116
yield decode_value(response.result)
118117
except AioRpcError as e:
119118
match e.code():
@@ -128,7 +127,7 @@ async def streamingcall_async(self, method, *args):
128127

129128
@asynccontextmanager
130129
async def stream_async(self, method):
131-
context = self.Stream(
130+
context = self.stub.Stream(
132131
metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method))
133132
.model_dump(mode="json", round_trip=True)
134133
.items(),
@@ -142,7 +141,7 @@ async def resource_async(
142141
self,
143142
stream,
144143
):
145-
context = self.Stream(
144+
context = self.stub.Stream(
146145
metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid))
147146
.model_dump(mode="json", round_trip=True)
148147
.items(),
@@ -160,7 +159,7 @@ def __log(self, level: int, msg: str):
160159
@asynccontextmanager
161160
async def log_stream_async(self):
162161
async def log_stream():
163-
async for response in self.LogStream(empty_pb2.Empty()):
162+
async for response in self.stub.LogStream(empty_pb2.Empty()):
164163
self.__log(logging.getLevelName(response.severity), response.message)
165164

166165
async with create_task_group() as tg:

packages/jumpstarter/jumpstarter/client/grpc.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from collections import OrderedDict
44
from dataclasses import InitVar, dataclass, field
55
from datetime import datetime, timedelta
6-
from functools import partial
7-
from typing import Generic, Type, TypeVar
6+
from types import SimpleNamespace
7+
from typing import Any
88

99
import yaml
1010
from google.protobuf import duration_pb2, field_mask_pb2, json_format
@@ -256,19 +256,18 @@ async def DeleteLease(self, *, name: str):
256256
)
257257

258258

259-
T = TypeVar("T")
260-
261-
262259
@dataclass(frozen=True, slots=True)
263-
class SmartStub(Generic[T]):
260+
class SmartExporterStub:
264261
channels: InitVar[list[Channel]]
265-
cls: InitVar[Type]
266262

267-
__stubs: dict[Channel, T] = field(init=False, default_factory=OrderedDict)
263+
__stubs: dict[Channel, Any] = field(init=False, default_factory=OrderedDict)
268264

269-
def __post_init__(self, channels, cls):
265+
def __post_init__(self, channels):
270266
for channel in channels:
271-
self.__stubs[channel] = cls(channel)
267+
stub = SimpleNamespace()
268+
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(stub, channel)
269+
router_pb2_grpc.RouterServiceStub.__init__(stub, channel)
270+
self.__stubs[channel] = stub
272271

273272
def __getattr__(self, name):
274273
for channel, stub in self.__stubs.items():
@@ -277,13 +276,3 @@ def __getattr__(self, name):
277276
return getattr(stub, name)
278277
# or fallback to the last channel (via router)
279278
return getattr(next(reversed(self.__stubs.values())), name)
280-
281-
282-
SmartExporterServiceStub = partial(
283-
SmartStub[jumpstarter_pb2_grpc.ExporterServiceStub],
284-
cls=jumpstarter_pb2_grpc.ExporterServiceStub,
285-
)
286-
SmartRouterServiceStub = partial(
287-
SmartStub[router_pb2_grpc.RouterServiceStub],
288-
cls=router_pb2_grpc.RouterServiceStub,
289-
)

0 commit comments

Comments
 (0)