Skip to content

Commit 5451132

Browse files
authored
RSDK-2506 Sessions client (#301)
1 parent 7ae6a3c commit 5451132

File tree

3 files changed

+274
-0
lines changed

3 files changed

+274
-0
lines changed

src/viam/robot/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from viam.resource.types import RESOURCE_TYPE_COMPONENT, RESOURCE_TYPE_SERVICE, Subtype
4242
from viam.rpc.dial import DialOptions, ViamChannel, dial
4343
from viam.services.service_base import ServiceBase
44+
from viam.sessions_client import SessionsClient
4445
from viam.utils import dict_to_struct
4546

4647
LOGGER = logging.getLogger(__name__)
@@ -93,6 +94,11 @@ class Options:
9394
The frequency (in seconds) at which to attempt to reconnect a disconnected robot. 0 (zero) signifies no reconnection attempts
9495
"""
9596

97+
disable_sessions: bool = False
98+
"""
99+
Whether sessions are disabled
100+
"""
101+
96102
@classmethod
97103
async def at_address(cls, address: str, options: Options) -> Self:
98104
"""Create a robot client that is connected to the robot at the provided address.
@@ -138,6 +144,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
138144
else:
139145
self._channel = channel.channel
140146
self._viam_channel = channel
147+
141148
self._connected = True
142149
self._client = RobotServiceStub(self._channel)
143150
self._manager = ResourceManager()
@@ -146,6 +153,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
146153
self._should_close_channel = close_channel
147154
self._options = options
148155
self._address = self._channel._path if self._channel._path else f"{self._channel._host}:{self._channel._port}"
156+
self._sessions_client = SessionsClient(self._channel, disabled=self._options.disable_sessions)
149157

150158
try:
151159
await self.refresh()
@@ -180,6 +188,7 @@ async def _with_channel(cls, channel: Union[Channel, ViamChannel], options: Opti
180188
_resource_names: List[ResourceName]
181189
_should_close_channel: bool
182190
_closed: bool = False
191+
_sessions_client: SessionsClient
183192

184193
async def refresh(self):
185194
"""
@@ -270,6 +279,8 @@ async def _check_connection(self, check_every: int, reconnect_every: int):
270279

271280
while not self._connected:
272281
try:
282+
self._sessions_client.reset()
283+
273284
channel = await dial(self._address, self._options.dial_options)
274285

275286
client: RobotServiceStub
@@ -286,12 +297,14 @@ async def _check_connection(self, check_every: int, reconnect_every: int):
286297
self._channel = channel.channel
287298
self._viam_channel = channel
288299
self._client = RobotServiceStub(self._channel)
300+
self._sessions_client = SessionsClient(channel=self._channel, disabled=self._options.disable_sessions)
289301

290302
await self.refresh()
291303
self._connected = True
292304
LOGGER.debug("Successfully reconnected robot")
293305
except Exception as e:
294306
LOGGER.error(f"Failed to reconnect, trying again in {reconnect_every}sec", exc_info=e)
307+
self._sessions_client.reset()
295308
self._close_channel()
296309
await asyncio.sleep(reconnect_every)
297310

@@ -423,6 +436,8 @@ async def close(self):
423436
except RuntimeError:
424437
pass
425438

439+
self._sessions_client.reset()
440+
426441
# Cancel all tasks created by VIAM
427442
LOGGER.debug("Closing tasks spawned by Viam")
428443
tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith(viam._TASK_PREFIX)]

src/viam/sessions_client.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import asyncio
2+
from datetime import timedelta
3+
from typing import Optional
4+
5+
from grpclib import Status
6+
from grpclib.client import Channel
7+
from grpclib.events import RecvTrailingMetadata, SendRequest, listen
8+
from grpclib.exceptions import GRPCError, StreamTerminatedError
9+
from grpclib.metadata import _MetadataLike
10+
11+
from viam import _TASK_PREFIX, logging
12+
from viam.proto.robot import RobotServiceStub, SendSessionHeartbeatRequest, StartSessionRequest, StartSessionResponse
13+
14+
LOGGER = logging.getLogger(__name__)
15+
# Metadata key under which the session id is attached to outgoing requests.
SESSION_METADATA_KEY = "viam-sid"

# Fully-qualified gRPC method names that never receive session metadata.
EXEMPT_METADATA_METHODS = frozenset(
    (
        # gRPC reflection
        "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
        # WebRTC signaling
        "/proto.rpc.webrtc.v1.SignalingService/Call",
        "/proto.rpc.webrtc.v1.SignalingService/CallUpdate",
        "/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig",
        # Authentication
        "/proto.rpc.v1.AuthService/Authenticate",
        # Robot service calls, including the session RPCs themselves
        "/viam.robot.v1.RobotService/ResourceNames",
        "/viam.robot.v1.RobotService/ResourceRPCSubtypes",
        "/viam.robot.v1.RobotService/StartSession",
        "/viam.robot.v1.RobotService/SendSessionHeartbeat",
    )
)
30+
31+
32+
async def delay(coro, seconds):
    """Sleep for ``seconds`` seconds, then await ``coro``.

    Args:
        coro: The awaitable to run once the delay has elapsed.
        seconds: How long to wait, in seconds, before awaiting ``coro``.

    Raises:
        asyncio.CancelledError: If the surrounding task is cancelled. When that happens
            during the sleep, ``coro`` is closed before the error propagates.
    """
    try:
        await asyncio.sleep(seconds)
    except BaseException:
        # Interrupted before ``coro`` ever started. Close it explicitly so Python does not
        # warn that the coroutine "was never awaited". Awaitables without ``close`` (e.g.
        # futures) are left untouched.
        close = getattr(coro, "close", None)
        if close is not None:
            close()
        raise

    await coro
35+
36+
37+
class SessionsClient:
    """
    A Session allows a client to express that it is actively connected and
    supports stopping actuating components when it's not.

    The client registers listeners on the channel: outgoing requests (other than those in
    ``EXEMPT_METADATA_METHODS``) get the session id attached as metadata, and a
    ``SESSION_EXPIRED`` trailer resets the session so the next request starts a new one.
    Once a session is started, heartbeats are sent periodically in a background task.
    """

    _current_id: str = ""
    _disabled: bool = False
    _lock: asyncio.Lock
    _supported: Optional[bool] = None
    _heartbeat_interval: Optional[timedelta] = None
    _heartbeat_task: Optional[asyncio.Task] = None

    def __init__(self, channel: Channel, *, disabled: bool = False):
        """
        Args:
            channel: The channel whose requests should carry session metadata.
            disabled: When True, no session is started and no metadata is attached.
        """
        self.channel = channel
        self.client = RobotServiceStub(channel)
        self._disabled = disabled
        # One lock per client, created alongside the client: a class-level lock would be
        # shared by every client and constructed at import time.
        self._lock = asyncio.Lock()

        listen(self.channel, SendRequest, self._send_request)
        listen(self.channel, RecvTrailingMetadata, self._recv_trailers)

    def reset(self):
        """Forget whether sessions are supported so the next request renegotiates.

        Does nothing while a session start is in progress (the lock is held).
        The current session id is kept so it can be offered for resumption.
        """
        if self._lock.locked():
            return

        LOGGER.debug("resetting session")
        self._supported = None

    async def _send_request(self, event: SendRequest):
        """Attach session metadata to an outgoing request unless disabled or exempt."""
        if self._disabled:
            return

        if event.method_name in EXEMPT_METADATA_METHODS:
            return

        event.metadata.update(await self.metadata)

    async def _recv_trailers(self, event: RecvTrailingMetadata):
        """Reset the session when the server reports that it has expired."""
        if event.status == Status.INVALID_ARGUMENT and event.status_message == "SESSION_EXPIRED":
            LOGGER.debug("Session expired")
            self.reset()

    @property
    async def metadata(self) -> _MetadataLike:
        """Return the metadata to attach to a request, starting a session if needed.

        Returns:
            ``{SESSION_METADATA_KEY: <id>}`` when a session is active, otherwise ``{}``.

        Raises:
            GRPCError: If starting the session fails for any reason other than the server
                not implementing sessions, or if the response is missing expected fields.
        """
        if self._disabled:
            return self._metadata

        if self._supported:
            return self._metadata

        async with self._lock:
            # Re-check under the lock: a concurrent caller may have finished negotiating
            # while this one was waiting. Without this, each waiter would start its own
            # session and its own heartbeat loop.
            if self._supported is not None:
                return self._metadata

            request = StartSessionRequest(resume=self._current_id)
            response: Optional[StartSessionResponse] = None

            try:
                response = await self.client.StartSession(request)
            except GRPCError as error:
                if error.status == Status.UNIMPLEMENTED:
                    # Server has no session support; remember that and send no metadata.
                    self._supported = False
                    return self._metadata
                else:
                    raise
            else:
                if response is None:
                    raise GRPCError(status=Status.INTERNAL, message="Expected response to start session")

                if response.heartbeat_window is None:
                    raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

                self._supported = True
                self._heartbeat_interval = response.heartbeat_window.ToTimedelta()
                self._current_id = response.id

        # Outside the lock on purpose: a failed heartbeat calls reset(), which is a no-op
        # while the lock is held.
        await self._heartbeat_tick()

        return self._metadata

    async def _heartbeat_tick(self):
        """Send one heartbeat and, on success, schedule the next one.

        Raises:
            GRPCError: If no heartbeat interval has been recorded.
        """
        if not self._supported:
            return

        # Wait for any in-progress session start to finish. Acquiring the lock suspends
        # this coroutine instead of spinning, so the event loop keeps running.
        async with self._lock:
            pass

        request = SendSessionHeartbeatRequest(id=self._current_id)

        if self._heartbeat_interval is None:
            raise GRPCError(status=Status.INTERNAL, message="Expected heartbeat window in response to start session")

        try:
            await self.client.SendSessionHeartbeat(request)
        except (GRPCError, StreamTerminatedError):
            LOGGER.debug("Heartbeat terminated", exc_info=True)
            self.reset()
        else:
            LOGGER.debug("Sent heartbeat successfully")
            # We send heartbeats slightly faster than the interval window to
            # ensure that we don't fall outside of it and expire the session.
            wait = self._heartbeat_interval.total_seconds() / 5
            # Keep a reference: the event loop only holds weak references to tasks, so an
            # unreferenced task may be garbage collected before it runs.
            self._heartbeat_task = asyncio.create_task(delay(self._heartbeat_tick(), wait), name=f"{_TASK_PREFIX}-heartbeat")

    @property
    def _metadata(self) -> _MetadataLike:
        """Return the session metadata for the current state without any network calls."""
        if self._supported and self._current_id != "":
            return {SESSION_METADATA_KEY: self._current_id}

        return {}

tests/test_sessions_client.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from datetime import timedelta
2+
3+
import pytest
4+
from google.protobuf.duration_pb2 import Duration
5+
from grpclib import GRPCError, Status
6+
from grpclib.server import Stream
7+
from grpclib.testing import ChannelFor
8+
9+
from viam.proto.robot import SendSessionHeartbeatRequest, SendSessionHeartbeatResponse, StartSessionRequest, StartSessionResponse
10+
from viam.resource.manager import ResourceManager
11+
from viam.robot.service import RobotService
12+
from viam.sessions_client import SESSION_METADATA_KEY, SessionsClient
13+
14+
SESSION_ID = "sid"
15+
HEARTBEAT_INTERVAL = 2
16+
17+
18+
@pytest.fixture(scope="function")
def service() -> RobotService:
    """Build a ``RobotService`` whose session RPCs are replaced by local stubs.

    ``StartSession`` answers with ``SESSION_ID`` and a heartbeat window of
    ``HEARTBEAT_INTERVAL`` seconds; ``SendSessionHeartbeat`` answers with an empty response.
    """

    async def StartSession(stream: Stream[StartSessionRequest, StartSessionResponse]) -> None:
        incoming = await stream.recv_message()
        assert incoming is not None
        window = Duration()
        window.FromTimedelta(timedelta(seconds=HEARTBEAT_INTERVAL))
        reply = StartSessionResponse(id=SESSION_ID, heartbeat_window=window)
        await stream.send_message(reply)

    async def SendSessionHeartbeat(stream: Stream[SendSessionHeartbeatRequest, SendSessionHeartbeatResponse]) -> None:
        incoming = await stream.recv_message()
        assert incoming is not None
        reply = SendSessionHeartbeatResponse()
        await stream.send_message(reply)

    resources = ResourceManager([])
    robot_service = RobotService(resources)
    # Override the session handlers on this instance only.
    robot_service.StartSession = StartSession
    robot_service.SendSessionHeartbeat = SendSessionHeartbeat

    return robot_service
40+
41+
42+
@pytest.fixture(scope="function")
def service_without_session(service: RobotService) -> RobotService:
    """Return the stubbed service with its ``StartSession`` override removed."""
    delattr(service, "StartSession")
    return service
46+
47+
48+
@pytest.fixture(scope="function")
def service_without_heartbeat(service: RobotService) -> RobotService:
    """Return the stubbed service with its ``SendSessionHeartbeat`` override removed."""
    delattr(service, "SendSessionHeartbeat")
    return service
52+
53+
54+
@pytest.mark.asyncio
async def test_init_client():
    """A freshly constructed client has no session id and has not negotiated support."""
    async with ChannelFor([]) as chan:
        sessions = SessionsClient(chan)
        assert sessions._current_id == ""
        assert sessions._supported is None
60+
61+
62+
@pytest.mark.asyncio
async def test_sessions_error():
    """Requesting metadata over a channel with no services raises an UNKNOWN ``GRPCError``."""
    async with ChannelFor([]) as chan:
        sessions = SessionsClient(chan)

        with pytest.raises(GRPCError) as raised:
            assert await sessions.metadata == {}

        assert raised.value.status == Status.UNKNOWN
71+
72+
73+
@pytest.mark.asyncio
async def test_sessions_not_supported():
    """A client already marked unsupported yields empty metadata and stays unsupported."""
    async with ChannelFor([]) as chan:
        sessions = SessionsClient(chan)
        sessions._supported = False
        assert await sessions.metadata == {}
        assert sessions._supported is False
80+
81+
82+
@pytest.mark.asyncio
async def test_sessions_not_implemented(service_without_session: RobotService):
    """Without the ``StartSession`` stub, metadata is empty and support is recorded as False."""
    async with ChannelFor([service_without_session]) as chan:
        sessions = SessionsClient(chan)
        assert await sessions.metadata == {}
        assert sessions._supported is False
88+
89+
90+
@pytest.mark.asyncio
async def test_sessions_heartbeat_disconnect(service_without_heartbeat: RobotService):
    """Without the heartbeat stub, metadata is empty and the client ends up reset."""
    async with ChannelFor([service_without_heartbeat]) as chan:
        sessions = SessionsClient(chan)
        assert await sessions.metadata == {}
        assert sessions._supported is None
96+
97+
98+
@pytest.mark.asyncio
async def test_sessions_heartbeat(service: RobotService):
    """With both stubs present, the client records the session id and heartbeat interval."""
    async with ChannelFor([service]) as chan:
        sessions = SessionsClient(chan)
        assert await sessions.metadata == {SESSION_METADATA_KEY: SESSION_ID}
        assert sessions._supported
        assert sessions._heartbeat_interval and sessions._heartbeat_interval.total_seconds() == HEARTBEAT_INTERVAL
        assert sessions._current_id == SESSION_ID
106+
107+
108+
@pytest.mark.asyncio
async def test_sessions_disabled(service: RobotService):
    """A disabled client returns empty metadata and never negotiates a session."""
    async with ChannelFor([service]) as chan:
        sessions = SessionsClient(chan, disabled=True)
        assert await sessions.metadata == {}
        assert sessions._supported is None
        assert not sessions._heartbeat_interval

0 commit comments

Comments
 (0)