Skip to content

Commit 71db42d

Browse files
committed
Drop messages on the client for closed streams/subscriptions
1 parent 7d011c7 commit 71db42d

File tree

5 files changed

+73
-6
lines changed

5 files changed

+73
-6
lines changed

replit_river/client_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ async def send_subscription(
234234
) from e
235235
except Exception as e:
236236
raise e
237+
finally:
238+
output.close()
237239

238240
async def send_stream(
239241
self,
@@ -335,6 +337,8 @@ async def _encode_stream() -> None:
335337
) from e
336338
except Exception as e:
337339
raise e
340+
finally:
341+
output.close()
338342

339343
async def send_close_stream(
340344
self,

replit_river/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,11 @@ async def _add_msg_to_stream(
525525
return
526526
try:
527527
await stream.put(msg.payload)
528-
except (RuntimeError, ChannelClosed) as e:
528+
except ChannelClosed:
529+
# The client is no longer interested in this stream,
530+
# just drop the message.
531+
pass
532+
except RuntimeError as e:
529533
raise InvalidMessageException(e) from e
530534

531535
async def _remove_acked_messages_in_buffer(self) -> None:

tests/common_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def upload_handler(
3939

4040
basic_upload: HandlerMapping = {
4141
("test_service", "upload_method"): (
42-
"upload",
42+
"upload-stream",
4343
upload_method_handler(upload_handler, deserialize_request, serialize_response),
4444
),
4545
}
@@ -54,7 +54,7 @@ async def subscription_handler(
5454

5555
basic_subscription: HandlerMapping = {
5656
("test_service", "subscription_method"): (
57-
"subscription",
57+
"subscription-stream",
5858
subscription_method_handler(
5959
subscription_handler, deserialize_request, serialize_response
6060
),

tests/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Mapping
1+
from typing import Any, Literal, Mapping
22

33
import nanoid
44
import pytest
@@ -16,7 +16,8 @@
1616
# Modular fixtures
1717
pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"]
1818

19-
HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]
19+
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
20+
HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]]
2021

2122

2223
def transport_message(

tests/test_communication.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import asyncio
22
from datetime import timedelta
3+
import logging
34
from typing import AsyncGenerator
45

56
import pytest
7+
from grpc.aio import grpc
68

79
from replit_river.client import Client
810
from replit_river.error_schema import RiverError
11+
from replit_river.rpc import subscription_method_handler
912
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
1013
from tests.common_handlers import (
1114
basic_rpc_method,
@@ -14,9 +17,12 @@
1417
basic_upload,
1518
)
1619
from tests.conftest import (
20+
HandlerMapping,
1721
deserialize_error,
22+
deserialize_request,
1823
deserialize_response,
1924
serialize_request,
25+
serialize_response,
2026
)
2127

2228

@@ -101,6 +107,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
101107
@pytest.mark.asyncio
102108
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
103109
async def test_subscription_method(client: Client) -> None:
110+
messages = []
104111
async for response in client.send_subscription(
105112
"test_service",
106113
"subscription_method",
@@ -110,7 +117,8 @@ async def test_subscription_method(client: Client) -> None:
110117
deserialize_error,
111118
):
112119
assert isinstance(response, str)
113-
assert "Subscription message" in response
120+
messages.append(response)
121+
assert messages == [f"Subscription message {i} for Bob" for i in range(5)]
114122

115123

116124
@pytest.mark.asyncio
@@ -213,3 +221,53 @@ async def stream_data() -> AsyncGenerator[str, None]:
213221
"Stream response for Stream Data 1",
214222
"Stream response for Stream Data 2",
215223
]
224+
225+
226+
async def flood_subscription_handler(
227+
request: str, context: grpc.aio.ServicerContext
228+
) -> AsyncGenerator[str, None]:
229+
for i in range(128 * 2):
230+
logging.warning(f"sending {i}")
231+
yield f"Subscription message {i} for {request}"
232+
233+
234+
flood_subscription: HandlerMapping = {
235+
("test_service", "flood_subscription_method"): (
236+
"subscription-stream",
237+
subscription_method_handler(
238+
flood_subscription_handler, deserialize_request, serialize_response
239+
),
240+
),
241+
}
242+
243+
244+
@pytest.mark.asyncio
245+
@pytest.mark.parametrize("handlers", [{**basic_rpc_method, **flood_subscription}])
246+
async def test_ignore_flood_subscription(client: Client) -> None:
247+
# Intentionally don't read from the subscription.
248+
sub = client.send_subscription(
249+
"test_service",
250+
"flood_subscription_method",
251+
"Initial Subscription Data",
252+
serialize_request,
253+
deserialize_response,
254+
deserialize_error,
255+
)
256+
257+
# read one entry to start the subscription
258+
await sub.__anext__()
259+
# close the subscription so we can signal that we're not
260+
# interested in the rest of the subscription.
261+
await sub.aclose()
262+
263+
# ensure that subsequent RPCs still work
264+
response = await client.send_rpc(
265+
"test_service",
266+
"rpc_method",
267+
"Alice",
268+
serialize_request,
269+
deserialize_response,
270+
deserialize_error,
271+
timedelta(seconds=20),
272+
)
273+
assert response == "Hello, Alice!"

0 commit comments

Comments
 (0)